diff --git a/.cicd/agent-kaniko.yaml b/.cicd/agent-kaniko.yaml deleted file mode 100644 index 2af6ade41..000000000 --- a/.cicd/agent-kaniko.yaml +++ /dev/null @@ -1,18 +0,0 @@ -apiVersion: v1 -kind: Pod -spec: - containers: - - name: kaniko - image: gcr.io/kaniko-project/executor:debug - command: [/busybox/cat] - tty: true - volumeMounts: - - name: kaniko-secret - mountPath: /secret - env: - - name: GOOGLE_APPLICATION_CREDENTIALS - value: /secret/kaniko-secret.json - volumes: - - name: kaniko-secret - secret: - secretName: kaniko-secret diff --git a/.cicd/agent-python.yaml b/.cicd/agent-python.yaml deleted file mode 100644 index c0abd5a32..000000000 --- a/.cicd/agent-python.yaml +++ /dev/null @@ -1,22 +0,0 @@ -apiVersion: v1 -kind: Pod -spec: - containers: - - name: python - image: python:3.7 - command: [cat] - tty: true - volumeMounts: - - { name: tmp, mountPath: /tmp } - - { name: docker, mountPath: /var/run/docker.sock } - - { name: kaniko-secret, mountPath: /secret } - env: - - { name: GOOGLE_APPLICATION_CREDENTIALS, value: /secret/kaniko-secret.json } - volumes: - - name: tmp - hostPath: { path: /tmp, type: Directory } - - name: docker - hostPath: { path: /var/run/docker.sock, type: File } - - name: kaniko-secret - secret: - secretName: kaniko-secret diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000..69a7abe71 --- /dev/null +++ b/.flake8 @@ -0,0 +1,9 @@ +[flake8] +max-line-length = 120 +ignore = E402, W504, F403, F405 +exclude = backend/substrapp/migrations/*, + backend/substrapp/tests/assets.py, + backend/backend/settings/*, + backend/node/migrations/*, + backend/medias/*, + .env, .venv diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 3b61902ed..19e65abf8 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,4 +2,4 @@ # the repo. Unless a later match takes precedence, # @global-owner1 and @global-owner2 will be requested for # review when someone opens a pull request. -* @camillemarini @GuillaumeCisco +* @inalgnu @GuillaumeCisco @Kelvin-M @samlesu diff --git a/.gitignore b/.gitignore index fab7f44d5..7b17fc603 100644 --- a/.gitignore +++ b/.gitignore @@ -106,29 +106,13 @@ venv.bak/ # idea files .idea -#secret files -SECRET - -# conf from substra-network -substrabac/substrapp/**/conf - -**/medias/* - -# ledger binary files -bin +# vscode files +.vscode/ -dryrun - -# test files -substrabac/substrabac/description.md -substrabac/substrabac/metrics.py +# secret files +SECRET -# docker database dir -postgres-data +# substra-backend files docker/docker-compose-dynamic.yaml - -network.json -benin_malin -# file for testing data creation with django command -data.json - +**/medias/* +backend/node/nodes/* diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 000000000..25664d1ff --- /dev/null +++ b/.travis.yml @@ -0,0 +1,46 @@ +language: python + +python: + - '3.7' + +cache: pip + +branches: + only: + - master + - dev + +env: + - DJANGO_SETTINGS_MODULE=backend.settings.test + +addons: + snaps: + - name: helm + confinement: classic + channel: stable + +before_script: + - helm init --client-only + +install: pip install -r backend/requirements.txt + +script: + - helm lint charts/substra-backend + - cd backend + - pip install flake8 + - flake8 + - coverage run manage.py test + +after_script: + - 'if ! 
git diff --quiet --exit-code $TRAVIS_COMMIT_RANGE charts; then CHART_CHANGED="true"; else CHART_CHANGED="false"; fi' + - 'if [ "$CHART_CHANGED" == "true" -a "$TRAVIS_PULL_REQUEST" == "false" ]; then helm dep update charts/substra-backend; fi' + - 'if [ "$CHART_CHANGED" == "true" -a "$TRAVIS_PULL_REQUEST" == "false" ]; then helm package charts/substra-backend; fi' + - 'if [ "$CHART_CHANGED" == "true" -a "$TRAVIS_PULL_REQUEST" == "false" ]; then git config --global user.email "travis@travis-ci.org"; fi' + - 'if [ "$CHART_CHANGED" == "true" -a "$TRAVIS_PULL_REQUEST" == "false" ]; then git config --global user.name "Travis CI"; fi' + - 'if [ "$CHART_CHANGED" == "true" -a "$TRAVIS_PULL_REQUEST" == "false" ]; then git clone https://${GH_TOKEN}@github.com/SubstraFoundation/charts.git substra-charts; fi' + - 'if [ "$CHART_CHANGED" == "true" -a "$TRAVIS_PULL_REQUEST" == "false" ]; then mv hlf-k8s-* substra-charts/; fi' + - 'if [ "$CHART_CHANGED" == "true" -a "$TRAVIS_PULL_REQUEST" == "false" ]; then cd substra-charts; fi' + - 'if [ "$CHART_CHANGED" == "true" -a "$TRAVIS_PULL_REQUEST" == "false" ]; then helm repo index .; fi' + - 'if [ "$CHART_CHANGED" == "true" -a "$TRAVIS_PULL_REQUEST" == "false" ]; then git add .; fi' + - 'if [ "$CHART_CHANGED" == "true" -a "$TRAVIS_PULL_REQUEST" == "false" ]; then git commit --message "Travis build: $TRAVIS_BUILD_NUMBER"; fi' + - 'if [ "$CHART_CHANGED" == "true" -a "$TRAVIS_PULL_REQUEST" == "false" ]; then git push --quiet --set-upstream origin master; fi' diff --git a/Jenkinsfile b/Jenkinsfile deleted file mode 100644 index b177c9826..000000000 --- a/Jenkinsfile +++ /dev/null @@ -1,110 +0,0 @@ -pipeline { - options { - timestamps () - timeout(time: 1, unit: 'HOURS') - buildDiscarder(logRotator(numToKeepStr: '5')) - } - - agent none - - stages { - stage('Abort previous builds'){ - steps { - milestone(Integer.parseInt(env.BUILD_ID)-1) - milestone(Integer.parseInt(env.BUILD_ID)) - } - } - - stage('Test & Build') { - parallel { - stage('Test') { - agent { - kubernetes { - label 'substrabac-test' - defaultContainer 'python' - yamlFile '.cicd/agent-python.yaml' - } - } - - steps { - sh "apt update" - sh "apt install curl && mkdir -p /tmp/download && curl -L https://download.docker.com/linux/static/stable/x86_64/docker-18.06.3-ce.tgz | tar -xz -C /tmp/download && mv /tmp/download/docker/docker /usr/local/bin/" - sh "docker login -u _json_key --password-stdin https://eu.gcr.io/substra-208412/ < /secret/kaniko-secret.json" - sh "apt install -y python3-pip python3-dev build-essential gfortran musl-dev postgresql-contrib git curl netcat" - - dir("substrabac") { - sh "pip install -r requirements.txt" - sh "DJANGO_SETTINGS_MODULE=substrabac.settings.test coverage run manage.py test" - sh "coverage report" - sh "coverage html" - } - } - - post { - success { - publishHTML target: [ - allowMissing: false, - alwaysLinkToLastBuild: false, - keepAll: true, - reportDir: 'substrabac/htmlcov', - reportFiles: 'index.html', - reportName: 'Coverage Report' - ] - } - } - } - - stage('Build substrabac') { - agent { - kubernetes { - label 'substrabac-kaniko-substrabac' - yamlFile '.cicd/agent-kaniko.yaml' - } - } - - steps { - container(name:'kaniko', shell:'/busybox/sh') { - sh '''#!/busybox/sh - /kaniko/executor -f `pwd`/docker/substrabac/Dockerfile -c `pwd` -d "eu.gcr.io/substra-208412/substrabac:$GIT_COMMIT" - ''' - } - } - } - - stage('Build celerybeat') { - agent { - kubernetes { - label 'substrabac-kaniko-celerybeat' - yamlFile '.cicd/agent-kaniko.yaml' - } - } - - steps { - 
container(name:'kaniko', shell:'/busybox/sh') { - sh '''#!/busybox/sh - /kaniko/executor -f `pwd`/docker/celerybeat/Dockerfile -c `pwd` -d "eu.gcr.io/substra-208412/celerybeat:$GIT_COMMIT" - ''' - } - } - } - - stage('Build celeryworker') { - agent { - kubernetes { - label 'substrabac-kaniko-celeryworker' - yamlFile '.cicd/agent-kaniko.yaml' - } - } - - steps { - container(name:'kaniko', shell:'/busybox/sh') { - sh '''#!/busybox/sh - /kaniko/executor -f `pwd`/docker/celeryworker/Dockerfile -c `pwd` -d "eu.gcr.io/substra-208412/celeryworker:$GIT_COMMIT" - ''' - } - } - } - } - } - } -} diff --git a/README.md b/README.md index 4be9757c2..f3aca1030 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,16 @@ -# Substrabac +# Substra-Backend Backend of the Substra platform ## Version -You will note substrabac use a versioned REST API with the header protocol. +You will note substra-backend use a versioned REST API with the header protocol. Current is `0.0`. ## Getting started 1: Prepare the django app 1. Clone the repo: ``` - git clone https://github.com/SubstraFoundation/substrabac + git clone https://github.com/SubstraFoundation/substra-backend ``` 2. :warning: Do this step only if your environment development is on linux. On linux systems, all the docker instances create files with `root` permissions. @@ -51,7 +51,7 @@ guillaume:165536:65536 ``` The first line should be added with the docker group (999 in my case). -Final step is to redownload all the dockers image, go in the substra-network project and rerun the `./bootstrap.sh` script. +Final step is to redownload all the dockers image, go in the hlf-k8s project and rerun the `./bootstrap.sh` script. Do not forget to build the substra-model image as described in the step 9 of this tutorial. 3. Install dependencies (might be useful to create a virtual environment before, eg using virtualenv and virtualenvwrapper): @@ -64,14 +64,14 @@ Do not forget to build the substra-model image as described in the step 9 of thi ```shell $> sudo su postgres $> psql - $ CREATE USER substrabac WITH PASSWORD 'substrabac' CREATEDB CREATEROLE SUPERUSER; + $ CREATE USER backend WITH PASSWORD 'backend' CREATEDB CREATEROLE SUPERUSER; ``` 6. Create two databases for both orgs: owkin and chu-nantes. A shell script is available, do not hesitate to run it. -It will drop the databases if they are already created, then create them and grant all privileges to your main user substrabac. +It will drop the databases if they are already created, then create them and grant all privileges to your main user backend. (If this is the first time you create the databases, you will see some warnings which are pointless): ```shell - $> ./substrabac/scripts/recreate_db.sh + $> ./scripts/recreate_db.sh ``` 7. 
We will populate data: @@ -79,56 +79,53 @@ It will drop the databases if they are already created, then create them and gra - With django migrations ```shell -SUBSTRABAC_ORG=owkin SUBSTRABAC_DEFAULT_PORT=8000 python substrabac/manage.py migrate --settings=substrabac.settings.dev -SUBSTRABAC_ORG=chu-nantes SUBSTRABAC_DEFAULT_PORT=8001 python substrabac/manage.py migrate --settings=substrabac.settings.dev +BACKEND_ORG=owkin BACKEND_DEFAULT_PORT=8000 python backend/manage.py migrate --settings=backend.settings.dev +BACKEND_ORG=chu-nantes BACKEND_DEFAULT_PORT=8001 python backend/manage.py migrate --settings=backend.settings.dev ``` -###### With fixtures (fixtures container has been run from substra-network, old behavior for testing) +###### With fixtures (fixtures container has been run from hlf-k8s, old behavior for testing) data in fixtures are relative to the data already set in the ledger if the fixtures container instance succeeded Two solutions: - With django migrations + load data ```shell -SUBSTRABAC_ORG=owkin SUBSTRABAC_DEFAULT_PORT=8000 python substrabac/manage.py migrate --settings=substrabac.settings.dev -SUBSTRABAC_ORG=chu-nantes SUBSTRABAC_DEFAULT_PORT=8001 python substrabac/manage.py migrate --settings=substrabac.settings.dev -SUBSTRABAC_ORG=owkin SUBSTRABAC_DEFAULT_PORT=8000 python substrabac/manage.py loaddata ./fixtures/data_owkin.json --settings=substrabac.settings.dev -SUBSTRABAC_ORG=chu-nantes SUBSTRABAC_DEFAULT_PORT=8001 python substrabac/manage.py loaddata ./fixtures/data_chu-nantes.json --settings=substrabac.settings.dev +BACKEND_ORG=owkin BACKEND_DEFAULT_PORT=8000 python backend/manage.py migrate --settings=backend.settings.dev +BACKEND_ORG=chu-nantes BACKEND_DEFAULT_PORT=8001 python backend/manage.py migrate --settings=backend.settings.dev +BACKEND_ORG=owkin BACKEND_DEFAULT_PORT=8000 python backend/manage.py loaddata ./fixtures/data_owkin.json --settings=backend.settings.dev +BACKEND_ORG=chu-nantes BACKEND_DEFAULT_PORT=8001 python backend/manage.py loaddata ./fixtures/data_chu-nantes.json --settings=backend.settings.dev ``` - From dumps: ```shell - $> ./substrabac/scripts/populate_db.sh + $> ./scripts/populate_db.sh ``` If you don't want to replicate the data in the ledger, simply run the django migrations. Populate media files ```shell - $> ./substrabac/scripts/load_fixtures.sh + $> ./scripts/load_fixtures.sh ``` It will clean the `medias` folders and create the `owkin` and `chu-nantes` folders in the `medias` folder. 8. Optional: Create a superuser in your databases: -``` -SUBSTRABAC_ORG=owkin SUBSTRABAC_DEFAULT_PORT=8000 python substrabac/manage.py createsuperuser --settings=substrabac.settings.dev -SUBSTRABAC_ORG=chu-nantes SUBSTRABAC_DEFAULT_PORT=8001 python substrabac/manage.py createsuperuser --settings=substrabac.settings.dev +```shell +BACKEND_ORG=owkin BACKEND_DEFAULT_PORT=8000 ./backend/manage.py createsuperuser --settings=backend.settings.dev +BACKEND_ORG=chu-nantes BACKEND_DEFAULT_PORT=8001 ./backend/manage.py createsuperuser --settings=backend.settings.dev ``` 9. Build the substra-model docker image: -Clone the following git repo https://github.com/SubstraFoundation/substratools and build the docker image -``` +Clone the following git repo https://github.com/SubstraFoundation/substra-tools and build the docker image +```shell docker build -t substra-model . 
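+# run the build from the root of the cloned substra-tools repository (assumption: its Dockerfile sits at the top level)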
``` ## Getting started 2: Linking the app with Hyperledger Fabric -### Get Fabric binaries -Run `./boostrap.sh` +### Make the hlf-k8s available to the app -### Make the subtra-network available to the app - -[See here](https://github.com/SubstraFoundation/substra-network#network). +[See here](https://github.com/SubstraFoundation/hlf-k8s#network). ### Install rabbitmq @@ -138,33 +135,56 @@ sudo apt-get install rabbitmq-server ### Launch celery workers/scheduler and celery beat -Execute this command in the `substrabac/substrabac` folder. +Execute this command in the `backend/backend` folder. Note the use of the development settings. ```shell -DJANGO_SETTINGS_MODULE=substrabac.settings.dev SUBSTRABAC_ORG=owkin SUBSTRABAC_DEFAULT_PORT=8000 celery -E -A substrabac worker -l info -B -n owkin -Q owkin,scheduler,celery --hostname owkin.scheduler -DJANGO_SETTINGS_MODULE=substrabac.settings.dev SUBSTRABAC_ORG=owkin SUBSTRABAC_DEFAULT_PORT=8000 celery -E -A substrabac worker -l info -B -n owkin -Q owkin,owkin.worker,celery --hostname owkin.worker -DJANGO_SETTINGS_MODULE=substrabac.settings.dev SUBSTRABAC_ORG=owkin SUBSTRABAC_DEFAULT_PORT=8000 celery -E -A substrabac worker -l info -B -n owkin -Q owkin,owkin.dryrunner,celery --hostname owkin.dryrunner -DJANGO_SETTINGS_MODULE=substrabac.settings.dev SUBSTRABAC_ORG=chu-nantes SUBSTRABAC_DEFAULT_PORT=8001 celery -E -A substrabac worker -l info -B -n chunantes -Q chu-nantes,scheduler,celery --hostname chu-nantes.scheduler -DJANGO_SETTINGS_MODULE=substrabac.settings.dev SUBSTRABAC_ORG=chu-nantes SUBSTRABAC_DEFAULT_PORT=8001 celery -E -A substrabac worker -l info -B -n chunantes -Q chu-nantes,chu-nantes.worker,celery --hostname chu-nantes.worker -DJANGO_SETTINGS_MODULE=substrabac.settings.dev SUBSTRABAC_ORG=chu-nantes SUBSTRABAC_DEFAULT_PORT=8001 celery -E -A substrabac worker -l info -B -n chunantes -Q chu-nantes,chu-nantes.dryrunner,celery --hostname chu-nantes.dryrunner -DJANGO_SETTINGS_MODULE=substrabac.settings.common celery -A substrabac beat -l info +DJANGO_SETTINGS_MODULE=backend.settings.dev BACKEND_ORG=owkin BACKEND_DEFAULT_PORT=8000 BACKEND_PEER_PORT_EXTERNAL=9051 celery -E -A backend worker -l info -B -n owkin -Q owkin,scheduler,celery --hostname owkin.scheduler +DJANGO_SETTINGS_MODULE=backend.settings.dev BACKEND_ORG=owkin BACKEND_DEFAULT_PORT=8000 BACKEND_PEER_PORT_EXTERNAL=9051 celery -E -A backend worker -l info -B -n owkin -Q owkin,owkin.worker,celery --hostname owkin.worker +DJANGO_SETTINGS_MODULE=backend.settings.dev BACKEND_ORG=chu-nantes BACKEND_DEFAULT_PORT=8001 BACKEND_PEER_PORT_EXTERNAL=7051 celery -E -A backend worker -l info -B -n chunantes -Q chu-nantes,scheduler,celery --hostname chu-nantes.scheduler +DJANGO_SETTINGS_MODULE=backend.settings.dev BACKEND_ORG=chu-nantes BACKEND_DEFAULT_PORT=8001 BACKEND_PEER_PORT_EXTERNAL=7051 celery -E -A backend worker -l info -B -n chunantes -Q chu-nantes,chu-nantes.worker,celery --hostname chu-nantes.worker +DJANGO_SETTINGS_MODULE=backend.settings.common celery -A backend beat -l info ``` ## Launch the servers -Go in the `substrabac` folder and run the server locally: - ``` - SUBSTRABAC_ORG=owkin SUBSTRABAC_DEFAULT_PORT=8000 python manage.py runserver 8000 --settings=substrabac.settings.dev - SUBSTRABAC_ORG=chu-nantes SUBSTRABAC_DEFAULT_PORT=8001 python manage.py runserver 8001 --settings=substrabac.settings.dev - ``` +Go in the `backend` folder and run the server locally: +:warning:

Be careful: `--settings` is different here, the `server` settings module is needed.
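The `server` settings variants extend the corresponding base settings and additionally enable the apps that only the web server process should run; later in this diff, `backend/backend/settings/server/dev.py` is exactly:

```python
# backend/backend/settings/server/dev.py (content added by this change)
from ..dev import *

INSTALLED_APPS += ['events', 'node-register']
```

The celery workers above keep using `backend.settings.dev`, so only the server processes start the chaincode event listener provided by the `events` app.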

+
+```shell
+BACKEND_ORG=owkin BACKEND_DEFAULT_PORT=8000 BACKEND_PEER_PORT_EXTERNAL=9051 ./manage.py runserver 8000 --settings=backend.settings.server.dev
+BACKEND_ORG=chu-nantes BACKEND_DEFAULT_PORT=8001 BACKEND_PEER_PORT_EXTERNAL=7051 ./manage.py runserver 8001 --settings=backend.settings.server.dev
+```
+
+## Generate nodes authentication
+
+For node-to-node authentication, you need to generate and then load some fixtures:
+```shell
+python ./backend/node/generate_nodes.py
+BACKEND_ORG=owkin BACKEND_DEFAULT_PORT=8000 ./manage.py init_nodes ./backend/node/nodes/owkinMSP.json --settings=backend.settings.dev
+BACKEND_ORG=chu-nantes BACKEND_DEFAULT_PORT=8001 ./manage.py init_nodes ./backend/node/nodes/chu-nantesMSP.json --settings=backend.settings.dev
+```
+
+## Create a default user
 
-## Test with unit and functionnal tests
+A django admin command is available for registering a user:
+```shell
+./manage.py add_user $USERNAME $PASSWORD
+```
+
+The populate.py script uses the credentials `substra/p@$swr0d44` for each organization.
+Create these users with:
+
+```shell
+BACKEND_ORG=owkin ./backend/manage.py add_user substra 'p@$swr0d44' --settings=backend.settings.dev
+BACKEND_ORG=chu-nantes ./backend/manage.py add_user substra 'p@$swr0d44' --settings=backend.settings.dev
 ```
- DJANGO_SETTINGS_MODULE=substrabac.settings.test coverage run manage.py test
+
+## Test with unit and functional tests
+
+```shell
+ DJANGO_SETTINGS_MODULE=backend.settings.test coverage run manage.py test
 coverage report # For shell report
 coverage html # For HTML report
 ```
@@ -183,8 +203,8 @@ When you want to re-run the testing process:
 - Stop all your services and containers.
 - Rerun `recreate_db.sh` and `clean_media.sh` scripts.
 - Run the django migrations.
-- Relaunch your substra-network.
-- Run the owkin and chunantes substrabac servers.
+- Relaunch your hlf-k8s network.
+- Run the owkin and chunantes substra-backend servers.
 - Run celery beat and celery owkin and chu-nantes.
 - Run the `populate.py` python script.
@@ -213,22 +233,36 @@ Now you can reach `http://localhost:8000/` and `http://localhost:8001/` :tada:
 
 ## Launching with docker
 
-As for substra-network, you can launch all the services in docker containers.|
-First, build the images:
+As for hlf-k8s, you can launch all the services in docker containers.
+
+First, make sure you've generated the node artifacts:
+```bash
+$> python ./backend/node/generate_nodes.py
+```
+
+Then, build the images:
 ```bash
 $> sh build-docker-images.sh
 ```
+
+Then, go to the `docker` dir and run `start.py` (`-d` means `dev` settings):
 ```bash
-$> python3 start.py
+$> python start.py -d --no-backup
 ```
 
 Check your services are correctly started with `docker ps -a`.
+
+## Expiry token period
+
+Two global environment variables, `ACCESS_TOKEN_LIFETIME` and `EXPIRY_TOKEN_LIFETIME`, expressed in minutes, can be set to control the token expiry periods.
+The first one, `ACCESS_TOKEN_LIFETIME`, deals with JWT authentication.
+The second one, `EXPIRY_TOKEN_LIFETIME`, deals with simple token expiration.
+By default, both are set to 24*60 minutes, i.e. 24h.
+
 ## Testing fabric-sdk-py
 
 A directory named `fabric-sdk-py_tests` is available to the root of this project.
-If you launch a substra-network setup, you will be able to play with theses tests.
+If you launch a hlf-k8s setup, you will be able to play with these tests.
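As a rough sketch of the kind of call these tests make (the org, peer and profile names below are placeholders, and exact fabric-sdk-py signatures can vary between versions):

```python
import asyncio

from hfc.fabric import Client

loop = asyncio.get_event_loop()

# network.json is produced by generateNetworkFile.py (see below)
client = Client(net_profile='network.json')
admin = client.get_user(org_name='owkin', name='Admin')

# List the channels the peer has joined, the same query the backend runs at startup
response = loop.run_until_complete(
    client.query_channels(requestor=admin, peers=['peer0-owkin'], decode=True)
)
print([ch.channel_id for ch in response.channels])
```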
For `fabric-sdk-py-query-invoke.py`, be sure to have run the `generateNetworkFile.py` script for producing the network.json file needed. ## Miscellaneous @@ -248,15 +282,13 @@ Use these configurations for easier debugging and productivity: ![](assets/server_chunantes.png) ![](assets/celery owkin worker.png) ![](assets/celery owkin scheduler.png) -![](assets/celery owkin dryrunner.png) ![](assets/celery chunantes worker.png) ![](assets/celery chunantes scheduler.png) -![](assets/celery chunantes dryrunner.png) ![](assets/celery_beat.png) Do not hesitate to put breakpoints in your code. Even with periodic celery tasks and hit the `bug` button for launching your pre configurations. -You can even access directly to the databases (password is `substrabac` as described in the beginning of this document): +You can even access directly to the databases (password is `backend` as described in the beginning of this document): ![](assets/database_owkin.png) ![](assets/database_owkin_challenges.png) diff --git a/assets/celery chunantes dryrunner.png b/assets/celery chunantes dryrunner.png deleted file mode 100644 index 343136dc3..000000000 Binary files a/assets/celery chunantes dryrunner.png and /dev/null differ diff --git a/assets/celery chunantes scheduler.png b/assets/celery chunantes scheduler.png index 444ce72ee..5d7a796aa 100644 Binary files a/assets/celery chunantes scheduler.png and b/assets/celery chunantes scheduler.png differ diff --git a/assets/celery chunantes worker.png b/assets/celery chunantes worker.png index 88e72109d..62b09bec6 100644 Binary files a/assets/celery chunantes worker.png and b/assets/celery chunantes worker.png differ diff --git a/assets/celery owkin dryrunner.png b/assets/celery owkin dryrunner.png deleted file mode 100644 index 60b95934d..000000000 Binary files a/assets/celery owkin dryrunner.png and /dev/null differ diff --git a/assets/celery owkin scheduler.png b/assets/celery owkin scheduler.png index 0bf3a6583..994637200 100644 Binary files a/assets/celery owkin scheduler.png and b/assets/celery owkin scheduler.png differ diff --git a/assets/celery owkin worker.png b/assets/celery owkin worker.png index c73ab3a5f..6739570bf 100644 Binary files a/assets/celery owkin worker.png and b/assets/celery owkin worker.png differ diff --git a/assets/celery_beat.png b/assets/celery_beat.png index d927db7f7..910d22db6 100644 Binary files a/assets/celery_beat.png and b/assets/celery_beat.png differ diff --git a/assets/database_owkin.png b/assets/database_owkin.png index 0a1331062..b4c1718ac 100644 Binary files a/assets/database_owkin.png and b/assets/database_owkin.png differ diff --git a/assets/database_owkin_challenges.png b/assets/database_owkin_challenges.png index fcc1fc13e..fdc2bea94 100644 Binary files a/assets/database_owkin_challenges.png and b/assets/database_owkin_challenges.png differ diff --git a/assets/django_enabled.png b/assets/django_enabled.png index f42484383..3a955f3cd 100644 Binary files a/assets/django_enabled.png and b/assets/django_enabled.png differ diff --git a/assets/multirun.png b/assets/multirun.png index 4b9d56b92..4b514ec83 100644 Binary files a/assets/multirun.png and b/assets/multirun.png differ diff --git a/assets/server_chunantes.png b/assets/server_chunantes.png index 6068e04ba..53909a785 100644 Binary files a/assets/server_chunantes.png and b/assets/server_chunantes.png differ diff --git a/assets/server_owkin.png b/assets/server_owkin.png index b05ac142a..60d2e83a3 100644 Binary files a/assets/server_owkin.png and b/assets/server_owkin.png 
differ diff --git a/assets/sources_root.png b/assets/sources_root.png index e1c65b08d..86aeeca36 100644 Binary files a/assets/sources_root.png and b/assets/sources_root.png differ diff --git a/substrabac/.coveragerc b/backend/.coveragerc similarity index 53% rename from substrabac/.coveragerc rename to backend/.coveragerc index a5b94c111..809edc1ed 100644 --- a/substrabac/.coveragerc +++ b/backend/.coveragerc @@ -4,8 +4,8 @@ source = ./substrapp/ omit = - ./substrapp/apps.py - ./substrapp/serializers/ledger/* + ./substrapp/tests/generate_assets.py + ./substrapp/management/* [report] # Regexes for lines to exclude from consideration @@ -17,20 +17,10 @@ exclude_lines = def __repr__ if self\.debug - # Don't complain if tests don't hit defensive assertion code: - # raise - # Don't complain if non-runnable code isn't run: if 0: if __name__ == .__main__.: - # Don't complain if exception code isn't run: - logging - # except: - # except Exception - # except FileNotFoundError - EXCEPTIONS_MAP = dict() - # Don't complain if no gpu during test if gpu_set if __gpu_list @@ -38,10 +28,5 @@ exclude_lines = if 'environment' in job_args # Ignore functions - # def queryLedger - # def invokeLedger - # def getObjectFromLedger - def prepareTrainingTask - def prepareTestingTask - # class CustomFileResponse - # class ManageFileMixin + def prepare_training_task + def prepare_testing_task diff --git a/substrabac/substrabac/__init__.py b/backend/backend/__init__.py similarity index 88% rename from substrabac/substrabac/__init__.py rename to backend/backend/__init__.py index d128d39cd..070e835d0 100644 --- a/substrabac/substrabac/__init__.py +++ b/backend/backend/__init__.py @@ -4,4 +4,4 @@ # Django starts so that shared_task will use this app. from .celery import app as celery_app -__all__ = ('celery_app',) \ No newline at end of file +__all__ = ('celery_app',) diff --git a/backend/backend/celery.py b/backend/backend/celery.py new file mode 100644 index 000000000..df260609a --- /dev/null +++ b/backend/backend/celery.py @@ -0,0 +1,41 @@ +from __future__ import absolute_import, unicode_literals +import os +from celery import Celery +from celery import current_app +from celery.signals import after_task_publish + +# set the default Django settings module for the 'celery' program. +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings.prod') + +app = Celery('backend') + +# Using a string here means the worker doesn't have to serialize +# the configuration object to child processes. +# - namespace='CELERY' means all celery-related configuration keys +# should have a `CELERY_` prefix. +app.config_from_object('django.conf:settings', namespace='CELERY') + +# Load task modules from all registered Django app configs. +app.autodiscover_tasks() + + +@app.on_after_configure.connect +def setup_periodic_tasks(sender, **kwargs): + from substrapp.tasks.tasks import prepare_training_task, prepare_testing_task + + period = 3 * 3600 + sender.add_periodic_task(period, prepare_training_task.s(), queue='scheduler', + name='query Traintuples to prepare train task on todo traintuples') + sender.add_periodic_task(period, prepare_testing_task.s(), queue='scheduler', + name='query Testuples to prepare test task on todo testuples') + + +@after_task_publish.connect +def update_task_state(sender=None, headers=None, body=None, **kwargs): + # Change task.status to 'WAITING' for all tasks which are sent in. + # This allows one to distinguish between PENDING tasks which have been + # sent in and tasks which do not exist. 
State will change to + # SUCCESS, FAILURE, etc. once the process terminates. + task = current_app.tasks.get(sender) + backend = task.backend if task else current_app.backend + backend.store_result(headers['id'], None, 'WAITING') diff --git a/substrabac/substrabac/settings/__init__.py b/backend/backend/settings/__init__.py similarity index 100% rename from substrabac/substrabac/settings/__init__.py rename to backend/backend/settings/__init__.py diff --git a/substrabac/substrabac/settings/common.py b/backend/backend/settings/common.py similarity index 79% rename from substrabac/substrabac/settings/common.py rename to backend/backend/settings/common.py index 6598755b5..b1f0040db 100644 --- a/substrabac/substrabac/settings/common.py +++ b/backend/backend/settings/common.py @@ -1,5 +1,5 @@ """ -Django settings for substrabac project. +Django settings for backend project. Generated by 'django-admin startproject' using Django 2.0.5. @@ -10,7 +10,10 @@ https://docs.djangoproject.com/en/2.0/ref/settings/ """ -import os, sys, json +import os +import sys +from datetime import timedelta + from libs.gen_secret_key import write_secret_key # Build paths inside the project like this: os.path.join(BASE_DIR, ...) @@ -57,7 +60,16 @@ 'django_celery_results', 'rest_framework_swagger', 'rest_framework', + 'rest_framework.authtoken', + 'rest_framework_simplejwt.token_blacklist', 'substrapp', + 'node', + 'users' +] + +AUTHENTICATION_BACKENDS = [ + 'django.contrib.auth.backends.ModelBackend', + 'node.authentication.NodeBackend', ] MIDDLEWARE = [ @@ -66,12 +78,14 @@ 'django.middleware.common.CommonMiddleware', 'django.middleware.csrf.CsrfViewMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.auth.middleware.RemoteUserMiddleware', 'django.contrib.messages.middleware.MessageMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware', 'libs.SQLPrintingMiddleware.SQLPrintingMiddleware', + 'libs.HealthCheckMiddleware.HealthCheckMiddleware', ] -ROOT_URLCONF = 'substrabac.urls' +ROOT_URLCONF = 'backend.urls' TEMPLATES = [ { @@ -89,7 +103,7 @@ }, ] -WSGI_APPLICATION = 'substrabac.wsgi.application' +WSGI_APPLICATION = 'backend.wsgi.application' # Database # https://docs.djangoproject.com/en/1.9/ref/settings/#databases @@ -105,11 +119,23 @@ # https://docs.djangoproject.com/en/2.0/ref/settings/#auth-password-validators AUTH_PASSWORD_VALIDATORS = [ + { + 'NAME': 'libs.zxcvbnValidator.ZxcvbnValidator', + }, { 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', }, { 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', + 'OPTIONS': { + 'min_length': 9, + } + }, + { + 'NAME': 'libs.maximumLengthValidator.MaximumLengthValidator', + 'OPTIONS': { + 'max_length': 64 + } }, { 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', @@ -140,23 +166,29 @@ MEDIA_ROOT = os.path.join(PROJECT_ROOT, 'medias') MEDIA_URL = '/media/' -REST_FRAMEWORK = { - 'DEFAULT_RENDERER_CLASSES': ( - 'rest_framework.renderers.JSONRenderer', - 'rest_framework.renderers.AdminRenderer', - 'rest_framework.renderers.BrowsableAPIRenderer', - ) -} - SITE_ID = 1 -LEDGER_SYNC_ENABLED = True - CELERY_RESULT_BACKEND = 'django-db' CELERY_ACCEPT_CONTENT = ['application/json'] CELERY_RESULT_SERIALIZER = 'json' CELERY_TASK_SERIALIZER = 'json' CELERY_TASK_TRACK_STARTED = True # since 4.0 CELERY_WORKER_CONCURRENCY = 1 - CELERY_BROKER_URL = os.environ.get('CELERY_BROKER_URL', 'amqp://localhost:5672//'), + +DATA_UPLOAD_MAX_NUMBER_FIELDS = 10000 + 
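+# Lifetime of the simple (non-JWT) authentication tokens, read from the environment in minutes
+# and defaulting to 24h; e.g. EXPIRY_TOKEN_LIFETIME=60 would expire tokens after one hour.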
+EXPIRY_TOKEN_LIFETIME = timedelta(minutes=int(os.environ.get('EXPIRY_TOKEN_LIFETIME', 24*60))) + +TRUE_VALUES = { + 't', 'T', + 'y', 'Y', 'yes', 'YES', + 'true', 'True', 'TRUE', + 'on', 'On', 'ON', + '1', 1, + True +} + + +def to_bool(value): + return value in TRUE_VALUES diff --git a/substrabac/substrabac/settings/deps/__init__.py b/backend/backend/settings/deps/__init__.py similarity index 100% rename from substrabac/substrabac/settings/deps/__init__.py rename to backend/backend/settings/deps/__init__.py diff --git a/substrabac/substrabac/settings/deps/cors.py b/backend/backend/settings/deps/cors.py similarity index 100% rename from substrabac/substrabac/settings/deps/cors.py rename to backend/backend/settings/deps/cors.py diff --git a/backend/backend/settings/deps/ledger.py b/backend/backend/settings/deps/ledger.py new file mode 100644 index 000000000..66eda4120 --- /dev/null +++ b/backend/backend/settings/deps/ledger.py @@ -0,0 +1,251 @@ +import os +import base64 +import asyncio +import glob +import json +import tempfile + +from .org import ORG + +from hfc.fabric import Client +from hfc.fabric.peer import Peer +from hfc.fabric.user import create_user +from hfc.fabric.orderer import Orderer +from hfc.util.keyvaluestore import FileKeyValueStore +from hfc.fabric.block_decoder import decode_fabric_MSP_config, decode_fabric_peers_info, decode_fabric_endpoints + +SUBSTRA_FOLDER = os.getenv('SUBSTRA_PATH', '/substra') +LEDGER_CONFIG_FILE = os.environ.get('LEDGER_CONFIG_FILE', f'{SUBSTRA_FOLDER}/conf/{ORG}/substra-backend/conf.json') +LEDGER = json.load(open(LEDGER_CONFIG_FILE, 'r')) + +LEDGER_SYNC_ENABLED = True +LEDGER_CALL_RETRY = True + +PEER_PORT = LEDGER['peer']['port'][os.environ.get('BACKEND_PEER_PORT', 'external')] + +LEDGER['requestor'] = create_user( + name=LEDGER['client']['name'], + org=LEDGER['client']['org'], + state_store=FileKeyValueStore(LEDGER['client']['state_store']), + msp_id=LEDGER['client']['msp_id'], + key_path=glob.glob(LEDGER['client']['key_path'])[0], + cert_path=LEDGER['client']['cert_path'] +) + + +def get_hfc_client(): + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + client = Client() + + # Add peer from backend ledger config file + peer = Peer(name=LEDGER['peer']['name']) + peer.init_with_bundle({ + 'url': f'{LEDGER["peer"]["host"]}:{PEER_PORT}', + 'grpcOptions': LEDGER['peer']['grpcOptions'], + 'tlsCACerts': {'path': LEDGER['peer']['tlsCACerts']}, + 'clientKey': {'path': LEDGER['peer']['clientKey']}, + 'clientCert': {'path': LEDGER['peer']['clientCert']}, + }) + client._peers[LEDGER['peer']['name']] = peer + + # Check peer has joined channel + + response = loop.run_until_complete( + client.query_channels( + requestor=LEDGER['requestor'], + peers=[peer], + decode=True + ) + ) + + channels = [ch.channel_id for ch in response.channels] + + if not LEDGER['channel_name'] in channels: + raise Exception(f'Peer has not joined channel: {LEDGER["channel_name"]}') + + channel = client.new_channel(LEDGER['channel_name']) + + # Check chaincode is instantiated in the channel + + responses = loop.run_until_complete( + client.query_instantiated_chaincodes( + requestor=LEDGER['requestor'], + channel_name=LEDGER['channel_name'], + peers=[peer], + decode=True + ) + ) + + chaincodes = [cc.name + for resp in responses + for cc in resp.chaincodes] + + if not LEDGER['chaincode_name'] in chaincodes: + raise Exception(f'Chaincode : {LEDGER["chaincode_name"]}' + f' is not instantiated in the channel : {LEDGER["channel_name"]}') + + # Discover orderers and peers 
from channel discovery + results = loop.run_until_complete( + channel._discovery( + LEDGER['requestor'], + peer, + config=True, + local=False, + interests=[{'chaincodes': [{'name': LEDGER['chaincode_name']}]}] + ) + ) + + results = deserialize_discovery(results) + + update_client_with_discovery(client, results) + + return loop, client + + +LEDGER['hfc'] = get_hfc_client + + +def update_client_with_discovery(client, discovery_results): + + # Get all msp tls root cert files + tls_root_certs = {} + + for mspid, msp_info in discovery_results['config']['msps'].items(): + tls_root_certs[mspid] = base64.decodebytes( + msp_info['tls_root_certs'].pop().encode() + ) + + # Load one peer per msp for endorsing transaction + for msp in discovery_results['members']: + peer_info = msp[0] + if peer_info['mspid'] != LEDGER['client']['msp_id']: + peer = Peer(name=peer_info['mspid']) + + with tempfile.NamedTemporaryFile() as tls_root_cert: + tls_root_cert.write(tls_root_certs[peer_info['mspid']]) + tls_root_cert.flush() + + url = peer_info['endpoint'] + external_port = os.environ.get('BACKEND_PEER_PORT_EXTERNAL', None) + # use case for external development + if external_port: + url = f"{peer_info['endpoint'].split(':')[0]}:{external_port}" + peer.init_with_bundle({ + 'url': url, + 'grpcOptions': { + 'grpc-max-send-message-length': 15, + 'grpc.ssl_target_name_override': peer_info['endpoint'].split(':')[0] + }, + 'tlsCACerts': {'path': tls_root_cert.name}, + 'clientKey': {'path': LEDGER['peer']['clientKey']}, # use peer creds (mutual tls) + 'clientCert': {'path': LEDGER['peer']['clientCert']}, # use peer creds (mutual tls) + }) + + client._peers[peer_info['mspid']] = peer + + # Load one orderer for broadcasting transaction + orderer_mspid, orderer_info = list(discovery_results['config']['orderers'].items())[0] + + orderer = Orderer(name=orderer_mspid) + + with tempfile.NamedTemporaryFile() as tls_root_cert: + tls_root_cert.write(tls_root_certs[orderer_mspid]) + tls_root_cert.flush() + + # Need loop + orderer.init_with_bundle({ + 'url': f"{orderer_info[0]['host']}:{orderer_info[0]['port']}", + 'grpcOptions': { + 'grpc-max-send-message-length': 15, + 'grpc.ssl_target_name_override': orderer_info[0]['host'] + }, + 'tlsCACerts': {'path': tls_root_cert.name}, + 'clientKey': {'path': LEDGER['peer']['clientKey']}, # use peer creds (mutual tls) + 'clientCert': {'path': LEDGER['peer']['clientCert']}, # use peer creds (mutual tls) + }) + + client._orderers[orderer_mspid] = orderer + + +def deserialize_discovery(response): + results = { + 'config': None, + 'members': [], + 'cc_query_res': None + } + + for res in response.results: + if res.config_result and res.config_result.msps and res.config_result.orderers: + results['config'] = deserialize_config(res.config_result) + + if res.members: + results['members'].extend(deserialize_members(res.members)) + + if res.cc_query_res and res.cc_query_res.content: + results['cc_query_res'] = deserialize_cc_query_res(res.cc_query_res) + + return results + + +def deserialize_config(config_result): + + results = {'msps': {}, + 'orderers': {}} + + for mspid in config_result.msps: + results['msps'][mspid] = decode_fabric_MSP_config( + config_result.msps[mspid].SerializeToString() + ) + + for mspid in config_result.orderers: + results['orderers'][mspid] = decode_fabric_endpoints( + config_result.orderers[mspid].endpoint + ) + + return results + + +def deserialize_members(members): + peers = [] + + for mspid in members.peers_by_org: + peer = decode_fabric_peers_info( + 
members.peers_by_org[mspid].peers + ) + peers.append(peer) + + return peers + + +def deserialize_cc_query_res(cc_query_res): + cc_queries = [] + + for cc_query_content in cc_query_res.content: + cc_query = { + 'chaincode': cc_query_content.chaincode, + 'endorsers_by_groups': {}, + 'layouts': [] + } + + for group in cc_query_content.endorsers_by_groups: + peers = decode_fabric_peers_info( + cc_query_content.endorsers_by_groups[group].peers + ) + + cc_query['endorsers_by_groups'][group] = peers + + for layout_content in cc_query_content.layouts: + layout = { + 'quantities_by_group': { + group: int(layout_content.quantities_by_group[group]) + for group in layout_content.quantities_by_group + } + } + cc_query['layouts'].append(layout) + + cc_queries.append(cc_query) + + return cc_queries diff --git a/backend/backend/settings/deps/org.py b/backend/backend/settings/deps/org.py new file mode 100644 index 000000000..ffe75c9b9 --- /dev/null +++ b/backend/backend/settings/deps/org.py @@ -0,0 +1,6 @@ +import os + +ORG = os.environ.get('BACKEND_ORG', 'substra') +DEFAULT_PORT = os.environ.get('BACKEND_DEFAULT_PORT', '8000') +ORG_NAME = ORG.replace('-', '') +ORG_DB_NAME = ORG.replace('-', '_').upper() diff --git a/substrabac/substrabac/settings/deps/raven.py b/backend/backend/settings/deps/raven.py similarity index 53% rename from substrabac/substrabac/settings/deps/raven.py rename to backend/backend/settings/deps/raven.py index 74d4479f1..e21214d6c 100644 --- a/substrabac/substrabac/settings/deps/raven.py +++ b/backend/backend/settings/deps/raven.py @@ -3,6 +3,6 @@ from sentry_sdk.integrations.django import DjangoIntegration sentry_sdk.init( - dsn=os.environ.get('RAVEN_URL', "https://cff352ba26fc49f19e01692db93bf951@sentry.io/1317743"), # Default to substrabac raven + dsn=os.environ.get("RAVEN_URL"), integrations=[DjangoIntegration()] ) diff --git a/backend/backend/settings/deps/restframework.py b/backend/backend/settings/deps/restframework.py new file mode 100644 index 000000000..a0e4dbe77 --- /dev/null +++ b/backend/backend/settings/deps/restframework.py @@ -0,0 +1,28 @@ +import os +from datetime import timedelta + +REST_FRAMEWORK = { + 'TEST_REQUEST_DEFAULT_FORMAT': 'json', + 'DEFAULT_RENDERER_CLASSES': ( + 'rest_framework.renderers.JSONRenderer', + # 'rest_framework.renderers.AdminRenderer', + 'rest_framework.renderers.BrowsableAPIRenderer', + ), + 'DEFAULT_AUTHENTICATION_CLASSES': [ + 'users.authentication.SecureJWTAuthentication', # for front/sdk/cli + 'libs.expiryTokenAuthentication.ExpiryTokenAuthentication', # for front/sdk/cli + 'libs.sessionAuthentication.CustomSessionAuthentication', # for web browsable api + ], + 'DEFAULT_PERMISSION_CLASSES': [ + 'rest_framework.permissions.IsAuthenticated', + ], + 'UNICODE_JSON': False, + 'DEFAULT_VERSIONING_CLASS': 'libs.versioning.AcceptHeaderVersioningRequired', + 'ALLOWED_VERSIONS': ('0.0',), + 'DEFAULT_VERSION': '0.0', +} + +SIMPLE_JWT = { + 'ACCESS_TOKEN_LIFETIME': timedelta(minutes=int(os.environ.get('ACCESS_TOKEN_LIFETIME', 24*60))), + 'AUTH_HEADER_TYPES': ('JWT',), +} diff --git a/substrabac/substrabac/settings/dev.py b/backend/backend/settings/dev.py similarity index 56% rename from substrabac/substrabac/settings/dev.py rename to backend/backend/settings/dev.py index f9c580123..efd31d313 100644 --- a/substrabac/substrabac/settings/dev.py +++ b/backend/backend/settings/dev.py @@ -1,23 +1,24 @@ import os from .common import * - -from .deps.restframework import * from .deps.cors import * +from .deps.org import * +from .deps.ledger import * +from 
.deps.restframework import * -DEBUG = True +BASICAUTH_USERNAME = os.environ.get('BACK_AUTH_USER', 'dev') +BASICAUTH_PASSWORD = os.environ.get('BACK_AUTH_PASSWORD', 'dev') -ORG = os.environ.get('SUBSTRABAC_ORG', 'substra') -DEFAULT_PORT = os.environ.get('SUBSTRABAC_DEFAULT_PORT', '8000') +DEBUG = True -ORG_NAME = ORG.replace('-', '') -ORG_DB_NAME = ORG.replace('-', '_').upper() +TASK = { + 'CAPTURE_LOGS': to_bool(os.environ.get('TASK_CAPTURE_LOGS', True)), + 'CLEAN_EXECUTION_ENVIRONMENT': to_bool(os.environ.get('TASK_CLEAN_EXECUTION_ENVIRONMENT', True)), + 'CACHE_DOCKER_IMAGES': to_bool(os.environ.get('TASK_CACHE_DOCKER_IMAGES', False)), +} -try: - LEDGER = json.load(open(f'/substra/conf/{ORG}/substrabac/conf.json', 'r')) -except: - pass +LEDGER_CALL_RETRY = False # Overwrite the ledger setting value # Database # https://docs.djangoproject.com/en/2.0/ref/settings/#databases @@ -25,24 +26,18 @@ DATABASES = { 'default': { 'ENGINE': 'django.db.backends.postgresql_psycopg2', - 'NAME': os.environ.get(f'SUBSTRABAC_{ORG_DB_NAME}_DB_NAME', f'substrabac_{ORG_NAME}'), - 'USER': os.environ.get('SUBSTRABAC_DB_USER', 'substrabac'), - 'PASSWORD': os.environ.get('SUBSTRABAC_DB_PWD', 'substrabac'), + 'NAME': os.environ.get(f'BACKEND_{ORG_DB_NAME}_DB_NAME', f'backend_{ORG_NAME}'), + 'USER': os.environ.get('BACKEND_DB_USER', 'backend'), + 'PASSWORD': os.environ.get('BACKEND_DB_PWD', 'backend'), 'HOST': os.environ.get('DATABASE_HOST', 'localhost'), 'PORT': 5432, } } MEDIA_ROOT = os.environ.get('MEDIA_ROOT', os.path.join(PROJECT_ROOT, f'medias/{ORG_NAME}')) -DRYRUN_ROOT = os.environ.get('DRYRUN_ROOT', os.path.join(PROJECT_ROOT, f'dryrun/{ORG}')) - -if not os.path.exists(DRYRUN_ROOT): - os.makedirs(DRYRUN_ROOT, exist_ok=True) -SITE_ID = 1 -SITE_HOST = f'{ORG_NAME}.substrabac' +SITE_HOST = f'substra-backend.{ORG_NAME}.xyz' SITE_PORT = DEFAULT_PORT - DEFAULT_DOMAIN = os.environ.get('DEFAULT_DOMAIN', f'http://{SITE_HOST}:{SITE_PORT}') LOGGING = { @@ -53,7 +48,7 @@ 'format': '%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s' }, 'simple': { - 'format': '%(levelname)s %(message)s' + 'format': '%(levelname)s - %(asctime)s - %(name)s - %(message)s' }, }, 'filters': { @@ -67,9 +62,14 @@ 'filters': ['require_debug_false'], 'class': 'django.utils.log.AdminEmailHandler' }, + 'console': { + 'level': 'DEBUG', + 'class': 'logging.StreamHandler', + 'formatter': 'simple' + }, 'error_file': { 'level': 'INFO', - 'filename': os.path.join(PROJECT_ROOT, 'substrabac.log'), + 'filename': os.path.join(PROJECT_ROOT, 'backend.log'), 'class': 'logging.handlers.RotatingFileHandler', 'maxBytes': 1 * 1024 * 1024, 'backupCount': 2, @@ -80,6 +80,11 @@ 'django.request': { 'handlers': ['mail_admins', 'error_file'], 'level': 'INFO', + 'propagate': False, + }, + 'events': { + 'handlers': ['console'], + 'level': 'DEBUG', 'propagate': True, }, } diff --git a/substrabac/substrabac/settings/prod.py b/backend/backend/settings/prod.py similarity index 68% rename from substrabac/substrabac/settings/prod.py rename to backend/backend/settings/prod.py index 40cb1fee6..7b0371157 100644 --- a/substrabac/substrabac/settings/prod.py +++ b/backend/backend/settings/prod.py @@ -1,26 +1,31 @@ import os from .common import * - -from .deps.restframework import * from .deps.cors import * from .deps.raven import * +from .deps.org import * +from .deps.ledger import * +from .deps.restframework import * + DEBUG = False -USE_X_FORWARDED_HOST = True -SECURE_PROXY_SSL_HEADER = ('HTTP_X_FORWARDED_PROTO', 'https') -os.environ['HTTPS'] = "on" 
-os.environ['wsgi.url_scheme'] = 'https' # safer -import os +TASK = { + 'CAPTURE_LOGS': to_bool(os.environ.get('TASK_CAPTURE_LOGS', True)), + 'CLEAN_EXECUTION_ENVIRONMENT': to_bool(os.environ.get('TASK_CLEAN_EXECUTION_ENVIRONMENT', True)), + 'CACHE_DOCKER_IMAGES': to_bool(os.environ.get('TASK_CACHE_DOCKER_IMAGES', False)), +} -ORG = os.environ.get('SUBSTRABAC_ORG', 'substra') -DEFAULT_PORT = os.environ.get('SUBSTRABAC_DEFAULT_PORT', '8000') +BASICAUTH_USERNAME = os.environ.get('BACK_AUTH_USER') +BASICAUTH_PASSWORD = os.environ.get('BACK_AUTH_PASSWORD') -ORG_NAME = ORG.replace('-', '') -ORG_DB_NAME = ORG.replace('-', '_').upper() +USE_X_FORWARDED_HOST = True +SECURE_PROXY_SSL_HEADER = ('HTTP_X_FORWARDED_PROTO', 'https') +os.environ['HTTPS'] = "on" +os.environ['wsgi.url_scheme'] = 'https' -LEDGER = json.load(open(f'/substra/conf/{ORG}/substrabac/conf.json', 'r')) +STATIC_URL = '/static/' +STATIC_ROOT = os.path.join(BASE_DIR, 'statics') # Database # https://docs.djangoproject.com/en/2.0/ref/settings/#databases @@ -28,32 +33,20 @@ DATABASES = { 'default': { 'ENGINE': 'django.db.backends.postgresql_psycopg2', - 'NAME': os.environ.get(f'SUBSTRABAC_{ORG_DB_NAME}_DB_NAME', f'substrabac_{ORG_NAME}'), - 'USER': os.environ.get('SUBSTRABAC_DB_USER', 'substrabac'), - 'PASSWORD': os.environ.get('SUBSTRABAC_DB_PWD', 'substrabac'), + 'NAME': os.environ.get(f'BACKEND_{ORG_DB_NAME}_DB_NAME', f'backend_{ORG_NAME}'), + 'USER': os.environ.get('BACKEND_DB_USER', 'backend'), + 'PASSWORD': os.environ.get('BACKEND_DB_PWD', 'backend'), 'HOST': os.environ.get('DATABASE_HOST', 'localhost'), 'PORT': 5432, } } -MEDIA_ROOT = f'/substra/medias/{ORG_NAME}' -DRYRUN_ROOT = f'/substra/dryrun/{ORG}' +MEDIA_ROOT = os.environ.get('MEDIA_ROOT', f'/substra/medias/{ORG_NAME}') -SITE_ID = 1 -SITE_HOST = os.environ.get('SITE_HOST', f'{ORG_NAME}.substrabac') +SITE_HOST = os.environ.get('SITE_HOST', f'substra-backend.{ORG_NAME}.xyz') SITE_PORT = os.environ.get('SITE_PORT', DEFAULT_PORT) - DEFAULT_DOMAIN = os.environ.get('DEFAULT_DOMAIN', f'http://{SITE_HOST}:{SITE_PORT}') -STATIC_URL = '/static/' -STATIC_ROOT = os.path.join(BASE_DIR, 'statics') - -# deactivate when public -BASICAUTH_USERNAME = os.environ.get('BACK_AUTH_USER', None) -BASICAUTH_PASSWORD = os.environ.get('BACK_AUTH_PASSWORD', None) -MIDDLEWARE += ['libs.BasicAuthMiddleware.BasicAuthMiddleware'] - - LOGGING = { 'version': 1, 'disable_existing_loggers': False, @@ -85,12 +78,12 @@ 'error_file': { 'class': 'logging.FileHandler', 'formatter': 'generic', - 'filename': '/var/log/substrabac.error.log', + 'filename': '/var/log/substra-backend.error.log', }, 'access_file': { 'class': 'logging.FileHandler', 'formatter': 'generic', - 'filename': '/var/log/substrabac.access.log', + 'filename': '/var/log/substra-backend.access.log', }, }, 'loggers': { diff --git a/substrabac/substrapp/fixtures/__init__.py b/backend/backend/settings/server/__init__.py similarity index 100% rename from substrabac/substrapp/fixtures/__init__.py rename to backend/backend/settings/server/__init__.py diff --git a/backend/backend/settings/server/dev.py b/backend/backend/settings/server/dev.py new file mode 100644 index 000000000..e6b079217 --- /dev/null +++ b/backend/backend/settings/server/dev.py @@ -0,0 +1,3 @@ +from ..dev import * + +INSTALLED_APPS += ['events', 'node-register'] diff --git a/substrabac/substrapp/management/utils/__init__.py b/backend/backend/settings/server/nobasicauth/__init__.py similarity index 100% rename from substrabac/substrapp/management/utils/__init__.py rename to 
backend/backend/settings/server/nobasicauth/__init__.py diff --git a/backend/backend/settings/server/nobasicauth/dev.py b/backend/backend/settings/server/nobasicauth/dev.py new file mode 100644 index 000000000..7c6c71cbc --- /dev/null +++ b/backend/backend/settings/server/nobasicauth/dev.py @@ -0,0 +1,3 @@ +from ..dev import * + +BASIC_AUTHENTICATION_MODULE = 'substrapp.views.utils' diff --git a/backend/backend/settings/server/nobasicauth/prod.py b/backend/backend/settings/server/nobasicauth/prod.py new file mode 100644 index 000000000..07963050b --- /dev/null +++ b/backend/backend/settings/server/nobasicauth/prod.py @@ -0,0 +1,3 @@ +from ..prod import * + +BASIC_AUTHENTICATION_MODULE = 'substrapp.views.utils' diff --git a/backend/backend/settings/server/prod.py b/backend/backend/settings/server/prod.py new file mode 100644 index 000000000..1b18b4cee --- /dev/null +++ b/backend/backend/settings/server/prod.py @@ -0,0 +1,3 @@ +from ..prod import * + +INSTALLED_APPS += ['events', 'node-register'] diff --git a/substrabac/substrabac/settings/test.py b/backend/backend/settings/test.py similarity index 62% rename from substrabac/substrabac/settings/test.py rename to backend/backend/settings/test.py index 262350ce2..add578481 100644 --- a/substrabac/substrabac/settings/test.py +++ b/backend/backend/settings/test.py @@ -1,6 +1,6 @@ -import os - from .common import * - -from .deps.restframework import * from .deps.cors import * +from .deps.restframework import * + +import logging +logging.disable(logging.CRITICAL) diff --git a/substrabac/substrabac/urls.py b/backend/backend/urls.py similarity index 66% rename from substrabac/substrabac/urls.py rename to backend/backend/urls.py index 26aadf7fa..fecb890f7 100644 --- a/substrabac/substrabac/urls.py +++ b/backend/backend/urls.py @@ -1,4 +1,4 @@ -"""substrabac URL Configuration +"""backend URL Configuration The `urlpatterns` list routes URLs to views. 
For more information please see: https://docs.djangoproject.com/en/2.0/topics/http/urls/ @@ -19,8 +19,11 @@ from django.conf.urls.static import static from django.urls import include -from substrabac.views import schema_view +from backend.views import schema_view, obtain_auth_token + from substrapp.urls import router +from node.urls import router as nodeRouter +from users.urls import router as userRouter urlpatterns = [ @@ -28,5 +31,10 @@ url(r'^admin/', admin.site.urls), url(r'^doc/', schema_view), url(r'^', include((router.urls, 'substrapp'))), + url(r'^', include((nodeRouter.urls, 'node'))), + url(r'^', include((userRouter.urls, 'user'))), # for secure jwt authent + url(r'^api-auth/', include('rest_framework.urls')), # for session authent + url(r'^api-token-auth/', obtain_auth_token) # for expiry token authent ])), -] + static(settings.STATIC_URL, document_root=settings.STATIC_ROOT) + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) +] + static(settings.STATIC_URL, document_root=settings.STATIC_ROOT) \ + + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) diff --git a/substrabac/substrabac/views.py b/backend/backend/views.py similarity index 80% rename from substrabac/substrabac/views.py rename to backend/backend/views.py index 6fecb285b..abb1450c9 100644 --- a/substrabac/substrabac/views.py +++ b/backend/backend/views.py @@ -1,11 +1,18 @@ import yaml + +from rest_framework.authtoken.views import ObtainAuthToken from rest_framework.decorators import api_view, renderer_classes from rest_framework import response, schemas from rest_framework_swagger.renderers import OpenAPIRenderer, SwaggerUIRenderer +from rest_framework.compat import coreapi + +from rest_framework.authtoken.models import Token +from rest_framework.response import Response from django.conf.urls import url, include + +from libs.expiryTokenAuthentication import token_expire_handler, expires_at from substrapp.urls import router -from rest_framework.compat import coreapi from requests.compat import urlparse @@ -49,7 +56,7 @@ def get_link(self, path, method, view): if len(a) == 2: try: yaml_doc = yaml.load(a[1]) - except: + except BaseException: pass else: if 'desc' in yaml_doc: @@ -102,6 +109,27 @@ def get_link(self, path, method, view): @renderer_classes([OpenAPIRenderer, SwaggerUIRenderer]) def schema_view(request): generator = SchemaGenerator( - title='Substrabac API', + title='Substra Backend API', patterns=[url(r'^/', include([url(r'^', include(router.urls))]))]) return response.Response(generator.get_schema(request=request)) + + +class ExpiryObtainAuthToken(ObtainAuthToken): + + def post(self, request, *args, **kwargs): + serializer = self.serializer_class(data=request.data, + context={'request': request}) + serializer.is_valid(raise_exception=True) + user = serializer.validated_data['user'] + token, created = Token.objects.get_or_create(user=user) + + # token_expire_handler will check, if the token is expired it will generate new one + is_expired, token = token_expire_handler(token) + + return Response({ + 'token': token.key, + 'expires_at': expires_at(token) + }) + + +obtain_auth_token = ExpiryObtainAuthToken.as_view() diff --git a/substrabac/substrabac/wsgi.py b/backend/backend/wsgi.py similarity index 72% rename from substrabac/substrabac/wsgi.py rename to backend/backend/wsgi.py index 79ab173a2..7bd8cebfc 100644 --- a/substrabac/substrabac/wsgi.py +++ b/backend/backend/wsgi.py @@ -1,5 +1,5 @@ """ -WSGI config for substrabac project. +WSGI config for backend project. 
It exposes the WSGI callable as a module-level variable named ``application``. @@ -11,6 +11,6 @@ from django.core.wsgi import get_wsgi_application -os.environ.setdefault("DJANGO_SETTINGS_MODULE", "substrabac.settings.prod") +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "backend.settings.prod") application = get_wsgi_application() diff --git a/backend/events/__init__.py b/backend/events/__init__.py new file mode 100644 index 000000000..9cb9c2d5a --- /dev/null +++ b/backend/events/__init__.py @@ -0,0 +1 @@ +default_app_config = 'events.apps.EventsConfig' diff --git a/backend/events/apps.py b/backend/events/apps.py new file mode 100644 index 000000000..cc5a10a4c --- /dev/null +++ b/backend/events/apps.py @@ -0,0 +1,150 @@ +import asyncio +import json +import logging +import multiprocessing +import os +import time +import contextlib + +from django.apps import AppConfig + +from django.conf import settings + +import glob + +from hfc.fabric import Client +from hfc.fabric.peer import Peer +from hfc.fabric.user import create_user +from hfc.util.keyvaluestore import FileKeyValueStore + +from substrapp.tasks.tasks import prepare_tuple +from substrapp.utils import get_owner +from substrapp.ledger_utils import get_hfc + +from celery.result import AsyncResult + +logger = logging.getLogger(__name__) +LEDGER = getattr(settings, 'LEDGER', None) + + +@contextlib.contextmanager +def get_event_loop(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + yield loop + finally: + loop.close() + + +def on_tuples(cc_event, block_number, tx_id, tx_status): + payload = json.loads(cc_event['payload']) + owner = get_owner() + worker_queue = f"{LEDGER['name']}.worker" + + for tuple_type, _tuples in payload.items(): + if not _tuples: + continue + + for _tuple in _tuples: + key = _tuple['key'] + status = _tuple['status'] + + logger.info(f'Processing task {key}: type={tuple_type} status={status}' + f' with tx status: {tx_status}') + + if status != 'todo': + continue + + if tuple_type is None: + continue + + tuple_owner = _tuple['dataset']['worker'] + if tuple_owner != owner: + logger.debug(f'Skipping task {key}: owner does not match' + f' ({tuple_owner} vs {owner})') + continue + + if AsyncResult(key).state != 'PENDING': + logger.info(f'Skipping task {key}: already exists') + continue + + prepare_tuple.apply_async( + (_tuple, tuple_type), + task_id=key, + queue=worker_queue + ) + + +def wait(): + with get_event_loop() as loop: + channel_name = LEDGER['channel_name'] + chaincode_name = LEDGER['chaincode_name'] + peer = LEDGER['peer'] + + peer_port = peer["port"][os.environ.get('BACKEND_PEER_PORT', 'external')] + + client = Client() + + channel = client.new_channel(channel_name) + + target_peer = Peer(name=peer['name']) + requestor_config = LEDGER['client'] + + target_peer.init_with_bundle({ + 'url': f'{peer["host"]}:{peer_port}', + 'grpcOptions': peer['grpcOptions'], + 'tlsCACerts': {'path': peer['tlsCACerts']}, + 'clientKey': {'path': peer['clientKey']}, + 'clientCert': {'path': peer['clientCert']}, + }) + + try: + # can fail + requestor = create_user( + name=requestor_config['name'] + '_events', + org=requestor_config['org'], + state_store=FileKeyValueStore(requestor_config['state_store']), + msp_id=requestor_config['msp_id'], + key_path=glob.glob(requestor_config['key_path'])[0], + cert_path=requestor_config['cert_path'] + ) + except BaseException: + pass + else: + channel_event_hub = channel.newChannelEventHub(target_peer, + requestor) + + # use chaincode event + + # uncomment this line if you 
want to replay blocks from the beginning for debugging purposes + # stream = channel_event_hub.connect(start=0, filtered=False) + stream = channel_event_hub.connect(filtered=False) + + channel_event_hub.registerChaincodeEvent(chaincode_name, + 'tuples-updated', + onEvent=on_tuples) + + loop.run_until_complete(stream) + + +class EventsConfig(AppConfig): + name = 'events' + + def ready(self): + + # Try to connect a client first: if it fails, the backend will not start. + # This avoids potential issues when we launch the channel event hub in a subprocess + while True: + try: + with get_hfc() as (loop, client): + logger.info('Starting the event application.') + except Exception as e: + logger.exception(e) + time.sleep(5) + logger.info('Retrying to connect the event application to the ledger') + else: + break + + p1 = multiprocessing.Process(target=wait) + p1.start() diff --git a/backend/libs/HealthCheckMiddleware.py b/backend/libs/HealthCheckMiddleware.py new file mode 100644 index 000000000..cee4271bb --- /dev/null +++ b/backend/libs/HealthCheckMiddleware.py @@ -0,0 +1,27 @@ +from django.http import HttpResponse + + +class HealthCheckMiddleware(object): + def __init__(self, get_response): + self.get_response = get_response + # One-time configuration and initialization. + + def __call__(self, request): + if request.method == "GET": + if request.path == "/readiness": + return self.readiness(request) + elif request.path == "/liveness": + return self.liveness(request) + return self.get_response(request) + + def liveness(self, request): + """ + Returns that the server is alive. + """ + return HttpResponse("OK") + + def readiness(self, request): + """ + Returns that the server is ready to serve requests. + """ + return HttpResponse("OK") diff --git a/substrabac/libs/SQLPrintingMiddleware.py b/backend/libs/SQLPrintingMiddleware.py similarity index 98% rename from substrabac/libs/SQLPrintingMiddleware.py rename to backend/libs/SQLPrintingMiddleware.py index 6944596da..1293e2625 100644 --- a/substrabac/libs/SQLPrintingMiddleware.py +++ b/backend/libs/SQLPrintingMiddleware.py @@ -4,7 +4,6 @@ from django.conf import settings from django.db import connection -__author__ = 'guillaume' """ Originally code was taken from http://djangosnippets.org/snippets/290/ diff --git a/substrabac/substrapp/migrations/__init__.py b/backend/libs/__init__.py similarity index 100% rename from substrabac/substrapp/migrations/__init__.py rename to backend/libs/__init__.py diff --git a/backend/libs/expiryTokenAuthentication.py b/backend/libs/expiryTokenAuthentication.py new file mode 100644 index 000000000..1e470eb3a --- /dev/null +++ b/backend/libs/expiryTokenAuthentication.py @@ -0,0 +1,48 @@ +from django.conf import settings +from rest_framework.authentication import TokenAuthentication +from rest_framework.authtoken.models import Token +from rest_framework.exceptions import AuthenticationFailed + +from datetime import timedelta +from django.utils import timezone + + +# return the time left before the token expires +def expires_at(token): + time_elapsed = timezone.now() - token.created + left_time = getattr(settings, 'EXPIRY_TOKEN_LIFETIME') - time_elapsed + return left_time + + +# return True if the token has expired +def is_token_expired(token): + return expires_at(token) < timedelta(seconds=0) + + +# if the token has expired, it is removed +# and a new one with a different key is created in its place +def token_expire_handler(token): + is_expired = is_token_expired(token) + if is_expired: + token.delete() + token = 
Token.objects.create(user=token.user) + return is_expired, token + + +class ExpiryTokenAuthentication(TokenAuthentication): + """ + If the token has expired, it is deleted and authentication fails; + a new token with a different key is only created at the next login + """ + + def authenticate_credentials(self, key): + + _, token = super(ExpiryTokenAuthentication, self).authenticate_credentials(key) + + is_expired = is_token_expired(token) + if is_expired: + token.delete() + raise AuthenticationFailed('The token has expired') + + return (token.user, token) diff --git a/substrabac/libs/gen_secret_key.py b/backend/libs/gen_secret_key.py similarity index 100% rename from substrabac/libs/gen_secret_key.py rename to backend/libs/gen_secret_key.py diff --git a/backend/libs/maximumLengthValidator.py b/backend/libs/maximumLengthValidator.py new file mode 100644 index 000000000..61679d8a4 --- /dev/null +++ b/backend/libs/maximumLengthValidator.py @@ -0,0 +1,29 @@ +from django.core.exceptions import ValidationError +from django.utils.translation import ngettext + + +class MaximumLengthValidator: + """ + Validate whether the password is of a maximum length. + """ + def __init__(self, max_length=64): + self.max_length = max_length + + def validate(self, password, user=None): + if len(password) > self.max_length: + raise ValidationError( + ngettext( + "This password is too long. It must contain a maximum of %(max_length)d character.", + "This password is too long. It must contain a maximum of %(max_length)d characters.", + self.max_length + ), + code='password_too_long', + params={'max_length': self.max_length}, + ) + + def get_help_text(self): + return ngettext( + "Your password must contain a maximum of %(max_length)d character.", + "Your password must contain a maximum of %(max_length)d characters.", + self.max_length + ) % {'max_length': self.max_length} diff --git a/substrabac/libs/pagination.py b/backend/libs/pagination.py similarity index 89% rename from substrabac/libs/pagination.py rename to backend/libs/pagination.py index 397bc6b10..0e342a6c5 100644 --- a/substrabac/libs/pagination.py +++ b/backend/libs/pagination.py @@ -3,8 +3,6 @@ from __future__ import unicode_literals, absolute_import from rest_framework.pagination import PageNumberPagination -__author__ = 'guillaume' - class LimitedPagination(PageNumberPagination): page_size = 30 diff --git a/substrabac/libs/serializers.py b/backend/libs/serializers.py similarity index 97% rename from substrabac/libs/serializers.py rename to backend/libs/serializers.py index 9223dbf1a..35d815adf 100644 --- a/substrabac/libs/serializers.py +++ b/backend/libs/serializers.py @@ -3,8 +3,6 @@ from __future__ import unicode_literals, absolute_import from rest_framework import serializers -__author__ = 'guillaume' - class DynamicFieldsModelSerializer(serializers.ModelSerializer): """ diff --git a/backend/libs/sessionAuthentication.py b/backend/libs/sessionAuthentication.py new file mode 100644 index 000000000..85f28544c --- /dev/null +++ b/backend/libs/sessionAuthentication.py @@ -0,0 +1,19 @@ +from rest_framework.authentication import SessionAuthentication + + +class CustomSessionAuthentication(SessionAuthentication): + """ + Use Django's session framework for authentication. + """ + + def authenticate(self, request): + """ + Returns a `User` if the request session currently has a logged in user. + Otherwise returns `None`. 
+ """ + + # bypass for login with jwt + if request.resolver_match.url_name == 'user-login': + return None + + return super(CustomSessionAuthentication, self).authenticate(request) diff --git a/substrabac/libs/timestampModel.py b/backend/libs/timestampModel.py similarity index 100% rename from substrabac/libs/timestampModel.py rename to backend/libs/timestampModel.py diff --git a/substrabac/libs/versioning.py b/backend/libs/versioning.py similarity index 100% rename from substrabac/libs/versioning.py rename to backend/libs/versioning.py diff --git a/backend/libs/zxcvbnValidator.py b/backend/libs/zxcvbnValidator.py new file mode 100644 index 000000000..aadc1a474 --- /dev/null +++ b/backend/libs/zxcvbnValidator.py @@ -0,0 +1,19 @@ +from django.core.exceptions import ValidationError +from django.utils.translation import gettext as _ +from zxcvbn import zxcvbn + + +class ZxcvbnValidator: + + def validate(self, password, user=None): + results = zxcvbn(password, user_inputs=[user]) + + # score to the password, from 0 (terrible) to 4 (great) + if results['score'] < 3: + str = 'This password is not complex enough.' + if results['feedback']['warning']: + str += f"\nwarning: {results['feedback']['warning']}" + raise ValidationError(_(str), code='password_not_complex') + + def get_help_text(self): + return _("Your password must be a complex one") diff --git a/substrabac/manage.py b/backend/manage.py similarity index 85% rename from substrabac/manage.py rename to backend/manage.py index 570add98a..09333153a 100755 --- a/substrabac/manage.py +++ b/backend/manage.py @@ -3,7 +3,7 @@ import sys if __name__ == "__main__": - os.environ.setdefault("DJANGO_SETTINGS_MODULE", "substrabac.settings.dev") + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "backend.settings.dev") try: from django.core.management import execute_from_command_line except ImportError as exc: diff --git a/backend/node-register/__init__.py b/backend/node-register/__init__.py new file mode 100644 index 000000000..39b356278 --- /dev/null +++ b/backend/node-register/__init__.py @@ -0,0 +1 @@ +default_app_config = 'node-register.apps.NodeRegisterConfig' diff --git a/backend/node-register/apps.py b/backend/node-register/apps.py new file mode 100644 index 000000000..40875adf7 --- /dev/null +++ b/backend/node-register/apps.py @@ -0,0 +1,10 @@ +from django.apps import AppConfig +from substrapp.ledger_utils import invoke_ledger + + +class NodeRegisterConfig(AppConfig): + name = 'node-register' + + def ready(self): + # args is set to empty because fabric-sdk-py doesn't allow None args for invoke operations + invoke_ledger(fcn='registerNode', args=[''], sync=True) diff --git a/backend/node/__init__.py b/backend/node/__init__.py new file mode 100644 index 000000000..00990e886 --- /dev/null +++ b/backend/node/__init__.py @@ -0,0 +1 @@ +default_app_config = 'node.apps.NodeConfig' diff --git a/backend/node/apps.py b/backend/node/apps.py new file mode 100644 index 000000000..311682ad9 --- /dev/null +++ b/backend/node/apps.py @@ -0,0 +1,12 @@ +from django.apps import AppConfig +from django.db.models.signals import pre_save + + +class NodeConfig(AppConfig): + name = 'node' + + def ready(self): + from node.models import IncomingNode + from node.signals.node.pre_save import node_pre_save + + pre_save.connect(node_pre_save, sender=IncomingNode) diff --git a/backend/node/authentication.py b/backend/node/authentication.py new file mode 100644 index 000000000..2466e9894 --- /dev/null +++ b/backend/node/authentication.py @@ -0,0 +1,32 @@ +from 
django.contrib.auth.models import User +from django.core.exceptions import ObjectDoesNotExist + +from .models import IncomingNode + + +class NodeUser(User): + pass + + +# TODO: should be removed when node to node authent will be done via certificates +class NodeBackend: + """Authenticate node """ + + def authenticate(self, request, username=None, password=None): + """Check the username/password and return a user.""" + if not username or not password: + return None + + try: + node = IncomingNode.objects.get(node_id=username) + except ObjectDoesNotExist: + return None + else: + if node.check_password(password): + return NodeUser(username=username) + + return None + + def get_user(self, user_id): + # required for session + return None diff --git a/backend/node/generate_nodes.py b/backend/node/generate_nodes.py new file mode 100755 index 000000000..d1c336048 --- /dev/null +++ b/backend/node/generate_nodes.py @@ -0,0 +1,64 @@ +import json +import os +import secrets + + +def generate_secret(): + return secrets.token_hex(64) + + +def generate(orgs): + files = {} + + # TODO merge two loops + # init file content + for org in orgs: + data = { + 'incoming_nodes': [], + 'outgoing_nodes': [], + } + files[org] = data + + for org in orgs: + # create intern node (request from worker A to backend A) + secret = generate_secret() + files[org]['outgoing_nodes'].append({ + 'node_id': org, + 'secret': secret + }) + files[org]['incoming_nodes'].append({ + 'node_id': org, + 'secret': secret + }) + + for other_org in filter(lambda x: x != org, orgs): + # outgoing from server B to server A share same secret as incoming from server B in server A + secret = generate_secret() + files[other_org]['outgoing_nodes'].append({ # in server B + 'node_id': org, # to server A + 'secret': secret + }) + + files[org]['incoming_nodes'].append({ # in server A + 'node_id': other_org, # from server B + 'secret': secret + }) + + return files + + +def generate_for_orgs(orgs): + files = generate(orgs) + dir_path = os.path.dirname(os.path.realpath(__file__)) + nodes_path = os.path.join(dir_path, 'nodes') + os.makedirs(nodes_path, exist_ok=True) + for k, v in files.items(): + filepath = os.path.join(nodes_path, f'{k}.json') + with open(filepath, 'w') as f: + f.write(json.dumps(v, indent=4)) + + +if __name__ == '__main__': + orgs = ['owkinMSP', 'chu-nantesMSP', 'clbMSP'] # TODO should be discovered by discovery service + + generate_for_orgs(orgs) diff --git a/backend/node/management/commands/create_incoming_node.py b/backend/node/management/commands/create_incoming_node.py new file mode 100644 index 000000000..6053820e2 --- /dev/null +++ b/backend/node/management/commands/create_incoming_node.py @@ -0,0 +1,23 @@ +from django.core.management.base import BaseCommand +from node.models import Node, IncomingNode + + +class Command(BaseCommand): + help = 'Create a new incoming node' + + def add_arguments(self, parser): + parser.add_argument('node_id') + parser.add_argument('secret', nargs='?', default=Node.generate_secret()) + + def handle(self, *args, **options): + if IncomingNode.objects.filter(node_id=options['node_id']).exists(): + self.stdout.write(self.style.NOTICE(f'node with id {options["node_id"]} already exists')) + else: + incoming_node = IncomingNode.objects.create( + node_id=options['node_id'], + secret=options['secret'], + ) + + self.stdout.write(self.style.SUCCESS('node successfully created')) + self.stdout.write(f'node_id={incoming_node.node_id}') + self.stdout.write(f'secret={incoming_node.secret}') diff --git 
a/backend/node/management/commands/create_outgoing_node.py b/backend/node/management/commands/create_outgoing_node.py new file mode 100644 index 000000000..c87bcefcc --- /dev/null +++ b/backend/node/management/commands/create_outgoing_node.py @@ -0,0 +1,23 @@ +from django.core.management.base import BaseCommand +from node.models import Node, OutgoingNode + + +class Command(BaseCommand): + help = 'Create a new outgoing node' + + def add_arguments(self, parser): + parser.add_argument('node_id') + parser.add_argument('secret', nargs='?', default=Node.generate_secret()) + + def handle(self, *args, **options): + if OutgoingNode.objects.filter(node_id=options['node_id']).exists(): + self.stdout.write(self.style.NOTICE(f'node with id {options["node_id"]} already exists')) + else: + outgoing_node = OutgoingNode.objects.create( + node_id=options['node_id'], + secret=options['secret'], + ) + + self.stdout.write(self.style.SUCCESS('outgoing node successfully created')) + self.stdout.write(f'node_id={outgoing_node.node_id}') + self.stdout.write(f'secret={outgoing_node.secret}') diff --git a/backend/node/management/commands/get_incoming_node.py b/backend/node/management/commands/get_incoming_node.py new file mode 100644 index 000000000..19a18dd4c --- /dev/null +++ b/backend/node/management/commands/get_incoming_node.py @@ -0,0 +1,29 @@ +from django.core.management.base import BaseCommand +from node.models import IncomingNode + + +def pretty(s1, s2): + return f'{s1.ljust(64)} | {s2.ljust(128)}' + + +class Command(BaseCommand): + help = 'Get incoming nodes' + + def add_arguments(self, parser): + parser.add_argument('node_id', nargs='?') + + def handle(self, *args, **options): + self.stdout.write(pretty("node_id", "secret")) + self.stdout.write(pretty("_" * 64, "_" * 128)) + + if options['node_id']: + try: + incoming_node = IncomingNode.objects.get(node_id=options['node_id']) + except IncomingNode.DoesNotExist: + self.stdout.write(self.style.ERROR(f'Node with id {options["node_id"]} does not exist')) + else: + self.stdout.write(self.style.SUCCESS(pretty(incoming_node.node_id, incoming_node.secret))) + else: + incoming_nodes = IncomingNode.objects.all() + for node in incoming_nodes: + self.stdout.write(self.style.SUCCESS(pretty(node.node_id, node.secret))) diff --git a/backend/node/management/commands/get_outgoing_node.py b/backend/node/management/commands/get_outgoing_node.py new file mode 100644 index 000000000..fddd53f35 --- /dev/null +++ b/backend/node/management/commands/get_outgoing_node.py @@ -0,0 +1,29 @@ +from django.core.management.base import BaseCommand +from node.models import OutgoingNode + + +def pretty(s1, s2): + return f'{s1.ljust(64)} | {s2.ljust(128)}' + + +class Command(BaseCommand): + help = 'Get outgoing nodes' + + def add_arguments(self, parser): + parser.add_argument('node_id', nargs='?') + + def handle(self, *args, **options): + self.stdout.write(pretty("node_id", "secret")) + self.stdout.write(pretty("_" * 64, "_" * 128)) + + if options['node_id']: + try: + outgoing_node = OutgoingNode.objects.get(node_id=options['node_id']) + except OutgoingNode.DoesNotExist: + self.stdout.write(self.style.ERROR(f'Node with id {options["node_id"]} does not exist')) + else: + self.stdout.write(self.style.SUCCESS(pretty(outgoing_node.node_id, outgoing_node.secret))) + else: + outgoing_nodes = OutgoingNode.objects.all() + for node in outgoing_nodes: + self.stdout.write(self.style.SUCCESS(pretty(node.node_id, node.secret))) diff --git a/backend/node/management/commands/init_nodes.py 
b/backend/node/management/commands/init_nodes.py new file mode 100644 index 000000000..78875b155 --- /dev/null +++ b/backend/node/management/commands/init_nodes.py @@ -0,0 +1,25 @@ +import json + +from django.core.management.base import BaseCommand +from node.models import IncomingNode, OutgoingNode + + +class Command(BaseCommand): + help = 'Load nodes from file' + + def add_arguments(self, parser): + parser.add_argument('file') + + def handle(self, *args, **options): + + filepath = options['file'] + + with open(filepath) as json_file: + data = json.load(json_file) + + for node in data['incoming_nodes']: + IncomingNode.objects.create(node_id=node['node_id'], secret=node['secret']) + self.stdout.write(self.style.SUCCESS('created incoming node')) + for node in data['outgoing_nodes']: + OutgoingNode.objects.create(node_id=node['node_id'], secret=node['secret']) + self.stdout.write(self.style.SUCCESS('created outgoing node')) diff --git a/backend/node/migrations/0001_initial.py b/backend/node/migrations/0001_initial.py new file mode 100644 index 000000000..ba8a0d8f5 --- /dev/null +++ b/backend/node/migrations/0001_initial.py @@ -0,0 +1,34 @@ +# Generated by Django 2.1.2 on 2019-08-22 11:22 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ] + + operations = [ + migrations.CreateModel( + name='IncomingNode', + fields=[ + ('node_id', models.CharField(max_length=1024, primary_key=True, serialize=False)), + ('secret', models.CharField(max_length=128)), + ], + options={ + 'abstract': False, + }, + ), + migrations.CreateModel( + name='OutgoingNode', + fields=[ + ('node_id', models.CharField(max_length=1024, primary_key=True, serialize=False)), + ('secret', models.CharField(max_length=128)), + ], + options={ + 'abstract': False, + }, + ), + ] diff --git a/backend/node/migrations/0002_nodeuser.py b/backend/node/migrations/0002_nodeuser.py new file mode 100644 index 000000000..9eaca568a --- /dev/null +++ b/backend/node/migrations/0002_nodeuser.py @@ -0,0 +1,32 @@ +# Generated by Django 2.1.2 on 2019-09-26 09:25 + +from django.conf import settings +import django.contrib.auth.models +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('auth', '0009_alter_user_last_name_max_length'), + ('node', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='NodeUser', + fields=[ + ('user_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to=settings.AUTH_USER_MODEL)), + ], + options={ + 'verbose_name': 'user', + 'verbose_name_plural': 'users', + 'abstract': False, + }, + bases=('auth.user',), + managers=[ + ('objects', django.contrib.auth.models.UserManager()), + ], + ), + ] diff --git a/substrabac/substrapp/serializers/ledger/algo/__init__.py b/backend/node/migrations/__init__.py similarity index 100% rename from substrabac/substrapp/serializers/ledger/algo/__init__.py rename to backend/node/migrations/__init__.py diff --git a/backend/node/models.py b/backend/node/models.py new file mode 100644 index 000000000..16deceb02 --- /dev/null +++ b/backend/node/models.py @@ -0,0 +1,40 @@ +from django.contrib.auth.hashers import make_password, check_password +from django.db import models + +import secrets + + +class Node(models.Model): + node_id = models.CharField(primary_key=True, max_length=1024, blank=False) + secret = 
models.CharField(max_length=128, blank=False) + + @staticmethod + def generate_secret(): + return secrets.token_hex(64) + + def set_password(self, raw_secret): + self.secret = make_password(raw_secret) + self._secret = raw_secret + + def check_password(self, raw_secret): + """ + Return a boolean of whether the raw_password was correct. Handles + hashing formats behind the scenes. + """ + def setter(raw_secret): + self.set_password(raw_secret) + # Password hash upgrades shouldn't be considered password changes. + self._secret = None + self.save(update_fields=["secret"]) + return check_password(raw_secret, self.secret, setter) + + class Meta: + abstract = True + + +class OutgoingNode(Node): + pass + + +class IncomingNode(Node): + pass diff --git a/substrabac/substrapp/serializers/ledger/datamanager/__init__.py b/backend/node/signals/__init__.py similarity index 100% rename from substrabac/substrapp/serializers/ledger/datamanager/__init__.py rename to backend/node/signals/__init__.py diff --git a/substrabac/substrapp/serializers/ledger/datasample/__init__.py b/backend/node/signals/node/__init__.py similarity index 100% rename from substrabac/substrapp/serializers/ledger/datasample/__init__.py rename to backend/node/signals/node/__init__.py diff --git a/backend/node/signals/node/pre_save.py b/backend/node/signals/node/pre_save.py new file mode 100644 index 000000000..254cd664d --- /dev/null +++ b/backend/node/signals/node/pre_save.py @@ -0,0 +1,2 @@ +def node_pre_save(sender, instance, **kwargs): + instance.set_password(instance.secret) diff --git a/substrabac/substrapp/serializers/ledger/objective/__init__.py b/backend/node/tests/__init__.py similarity index 100% rename from substrabac/substrapp/serializers/ledger/objective/__init__.py rename to backend/node/tests/__init__.py diff --git a/substrabac/substrapp/serializers/ledger/testtuple/__init__.py b/backend/node/tests/views/__init__.py similarity index 100% rename from substrabac/substrapp/serializers/ledger/testtuple/__init__.py rename to backend/node/tests/views/__init__.py diff --git a/backend/node/tests/views/tests_views_node.py b/backend/node/tests/views/tests_views_node.py new file mode 100644 index 000000000..aeb3d111c --- /dev/null +++ b/backend/node/tests/views/tests_views_node.py @@ -0,0 +1,48 @@ +import os +import logging + +import mock + +from django.urls import reverse +from django.test import override_settings + +from rest_framework.test import APITestCase + +from substrapp.tests.common import AuthenticatedClient + +MEDIA_ROOT = "/tmp/unittests_views/" + + +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) +class ModelViewTests(APITestCase): + client_class = AuthenticatedClient + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0' + } + + self.logger = logging.getLogger('django.request') + self.previous_level = self.logger.getEffectiveLevel() + self.logger.setLevel(logging.ERROR) + + def tearDown(self): + self.logger.setLevel(self.previous_level) + + def test_node_list_success(self): + url = reverse('node:node-list') + with mock.patch('node.views.node.query_ledger') as mquery_ledger: + mquery_ledger.return_value = [{'id': 'foo'}, {'id': 'bar'}] + with mock.patch('node.views.node.get_owner') as mget_owner: + mget_owner.return_value = 'foo' + + response = self.client.get(url, **self.extra) + r = response.json() + self.assertEqual(r, [ + {'id': 'foo', 'isCurrent': True}, + 
{'id': 'bar', 'isCurrent': False} + ]) diff --git a/backend/node/urls.py b/backend/node/urls.py new file mode 100644 index 000000000..45c060f7a --- /dev/null +++ b/backend/node/urls.py @@ -0,0 +1,17 @@ +""" +node URL +""" + +from django.conf.urls import url, include +from rest_framework.routers import DefaultRouter + +from node.views import NodeViewSet + +# Create a router and register our viewsets with it. + +router = DefaultRouter() +router.register(r'node', NodeViewSet, base_name='node') + +urlpatterns = [ + url(r'^', include(router.urls)), +] diff --git a/backend/node/views/__init__.py b/backend/node/views/__init__.py new file mode 100644 index 000000000..35ac42ad6 --- /dev/null +++ b/backend/node/views/__init__.py @@ -0,0 +1,5 @@ +# encoding: utf-8 + +from .node import NodeViewSet + +__all__ = ['NodeViewSet'] diff --git a/backend/node/views/node.py b/backend/node/views/node.py new file mode 100644 index 000000000..efd1275ca --- /dev/null +++ b/backend/node/views/node.py @@ -0,0 +1,27 @@ +from rest_framework import status, mixins +from rest_framework.response import Response +from rest_framework.viewsets import GenericViewSet + +from substrapp.ledger_utils import query_ledger, LedgerError +from substrapp.utils import get_owner + + +class NodeViewSet(mixins.ListModelMixin, + GenericViewSet): + ledger_query_call = 'queryNodes' + + def get_queryset(self): + return [] + + def list(self, request, *args, **kwargs): + try: + nodes = query_ledger(fcn=self.ledger_query_call) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + + current_node_id = get_owner() + for node in nodes: + node.update({ + 'isCurrent': node['id'] == current_node_id, + }) + return Response(nodes, status=status.HTTP_200_OK) diff --git a/substrabac/requirements.txt b/backend/requirements.txt similarity index 71% rename from substrabac/requirements.txt rename to backend/requirements.txt index 53fd65156..3a620bb36 100644 --- a/substrabac/requirements.txt +++ b/backend/requirements.txt @@ -9,6 +9,7 @@ django-celery-results==1.0.4 django-filter==1.1.0 django-rest-swagger==2.1.2 djangorestframework==3.8.2 +djangorestframework-simplejwt==4.3.0 docker == 3.5.0 grpcio >= 1.0.1 GPUtil == 1.4.0 @@ -19,10 +20,14 @@ mock==2.0.0 psycopg2-binary==2.7.4 protobuf == 3.6.0 pycryptodomex >= 3.4.2 +pyOpenSSL == 19.0.0 pysha3 == 1.0b1 raven == 6.9.0 requests == 2.20.0 rx >= 1.5.3 sentry-sdk == 0.5.2 six >= 1.4.0 +tldextract == 2.2.1 uwsgi == 2.0.18 +zxcvbn==4.4.28 +git+git://github.com/hyperledger/fabric-sdk-py.git@36cc15021f74c11c7ae3196e380a5275c220145f # fabric-sdk-py==0.8.1 diff --git a/substrabac/substrapp/__init__.py b/backend/substrapp/__init__.py similarity index 100% rename from substrabac/substrapp/__init__.py rename to backend/substrapp/__init__.py diff --git a/substrabac/substrapp/admin.py b/backend/substrapp/admin.py similarity index 100% rename from substrabac/substrapp/admin.py rename to backend/substrapp/admin.py index 59380d882..0eb5e50bb 100644 --- a/substrabac/substrapp/admin.py +++ b/backend/substrapp/admin.py @@ -2,8 +2,8 @@ from substrapp.models import Objective, Model, DataSample, DataManager, Algo -admin.site.register(Objective) -admin.site.register(Model) -admin.site.register(DataSample) -admin.site.register(DataManager) admin.site.register(Algo) +admin.site.register(DataManager) +admin.site.register(DataSample) +admin.site.register(Model) +admin.site.register(Objective) diff --git a/substrabac/substrapp/apps.py b/backend/substrapp/apps.py similarity index 99% rename from 
substrabac/substrapp/apps.py rename to backend/substrapp/apps.py index c8cb6c42c..b967860d1 100644 --- a/substrabac/substrapp/apps.py +++ b/backend/substrapp/apps.py @@ -6,20 +6,21 @@ class SubstrappConfig(AppConfig): name = 'substrapp' def ready(self): + from .signals.datasample.pre_save import data_sample_pre_save + from .signals.algo.post_delete import algo_post_delete from .signals.objective.post_delete import objective_post_delete from .signals.datasample.post_delete import data_sample_post_delete from .signals.datamanager.post_delete import datamanager_post_delete from .signals.model.post_delete import model_post_delete - from .signals.datasample.pre_save import data_sample_pre_save # registering signals with the model's string label from substrapp.models import Algo, Objective, DataSample, DataManager, Model + pre_save.connect(data_sample_pre_save, sender=DataSample) + post_delete.connect(algo_post_delete, sender=Algo) post_delete.connect(objective_post_delete, sender=Objective) post_delete.connect(data_sample_post_delete, sender=DataSample) post_delete.connect(datamanager_post_delete, sender=DataManager) post_delete.connect(model_post_delete, sender=Model) - - pre_save.connect(data_sample_pre_save, sender=DataSample) diff --git a/backend/substrapp/ledger_utils.py b/backend/substrapp/ledger_utils.py new file mode 100644 index 000000000..21711f1fd --- /dev/null +++ b/backend/substrapp/ledger_utils.py @@ -0,0 +1,296 @@ +import contextlib +import functools +import json +import logging +import time + +from django.conf import settings +from rest_framework import status +from aiogrpc import RpcError + + +LEDGER = getattr(settings, 'LEDGER', None) +logger = logging.getLogger(__name__) + + +class LedgerError(Exception): + status = status.HTTP_400_BAD_REQUEST + + def __init__(self, msg): + super(LedgerError, self).__init__(msg) + self.msg = msg + + def __repr__(self): + return self.msg + + +class LedgerResponseError(LedgerError): + + @classmethod + def from_response(cls, response): + return LedgerResponseError(response['error']) + + +class LedgerConflict(LedgerResponseError): + + status = status.HTTP_409_CONFLICT + + def __init__(self, msg, pkhash): + super(LedgerConflict, self).__init__(msg) + self.pkhash = pkhash + + def __repr__(self): + return self.msg + + @classmethod + def from_response(cls, response): + pkhash = response.get('key') + if not pkhash: + return LedgerBadResponse(response['error']) + return LedgerConflict(response['error'], pkhash=pkhash) + + +class LedgerTimeout(LedgerError): + status = status.HTTP_408_REQUEST_TIMEOUT + + +class LedgerForbidden(LedgerResponseError): + status = status.HTTP_403_FORBIDDEN + + +class LedgerNotFound(LedgerResponseError): + status = status.HTTP_404_NOT_FOUND + + +class LedgerMVCCError(LedgerError): + status = status.HTTP_412_PRECONDITION_FAILED + + +class LedgerBadResponse(LedgerResponseError): + pass + + +class LedgerStatusError(LedgerError): + pass + + +STATUS_TO_EXCEPTION = { + status.HTTP_400_BAD_REQUEST: LedgerBadResponse, + status.HTTP_403_FORBIDDEN: LedgerForbidden, + status.HTTP_404_NOT_FOUND: LedgerNotFound, + status.HTTP_409_CONFLICT: LedgerConflict, +} + + +def retry_on_error(delay=1, nbtries=5, backoff=2): + def _retry(fn): + @functools.wraps(fn) + def _wrapper(*args, **kwargs): + if not getattr(settings, 'LEDGER_CALL_RETRY', False): + return fn(*args, **kwargs) + + _delay = delay + _nbtries = nbtries + _backoff = backoff + + while True: + try: + return fn(*args, **kwargs) + except (LedgerMVCCError, LedgerTimeout, LedgerBadResponse, 
RpcError) as e: + + _nbtries -= 1 + if not _nbtries: + raise + _delay *= _backoff + time.sleep(_delay) + logger.warning(f'Function {fn.__name__} failed: {e}; retrying in {_delay}s') + + return _wrapper + return _retry + + +@contextlib.contextmanager +def get_hfc(): + loop, client = LEDGER['hfc']() + try: + yield (loop, client) + finally: + del client + loop.close() + + +def call_ledger(call_type, fcn, args=None, kwargs=None): + + with get_hfc() as (loop, client): + if not args: + args = [] + else: + args = [json.dumps(args)] + + peer = LEDGER['peer'] + requestor = LEDGER['requestor'] + + chaincode_calls = { + 'invoke': client.chaincode_invoke, + 'query': client.chaincode_query, + } + + channel_name = LEDGER['channel_name'] + chaincode_name = LEDGER['chaincode_name'] + + peers = { + 'invoke': client._peers.keys(), + 'query': [peer['name']], + } + + params = { + 'requestor': requestor, + 'channel_name': channel_name, + 'peers': peers[call_type], + 'args': args, + 'cc_name': chaincode_name, + 'fcn': fcn + } + + if kwargs is not None and isinstance(kwargs, dict): + params.update(kwargs) + + try: + response = loop.run_until_complete(chaincode_calls[call_type](**params)) + except TimeoutError as e: + raise LedgerTimeout(str(e)) + except Exception as e: + if hasattr(e, 'details') and 'access denied' in e.details(): + raise LedgerForbidden(f'Access denied for {(fcn, args)}') + + try: # get first failed response from list of protobuf ProposalResponse + response = [r for r in e.args[0] if r.response.status != 200][0].response.message + except Exception: + raise LedgerError(str(e)) + + # Deserialize the stringified json + try: + response = json.loads(response) + except json.decoder.JSONDecodeError: + if response == 'MVCC_READ_CONFLICT': + raise LedgerMVCCError(response) + elif 'cannot change status' in response: + raise LedgerStatusError(response) + else: + raise LedgerBadResponse(response) + + if response and 'error' in response: + status_code = response['status'] + exception_class = STATUS_TO_EXCEPTION.get(status_code, LedgerBadResponse) + raise exception_class.from_response(response) + + return response + + +@retry_on_error() +def query_ledger(fcn, args=None): + # careful, passing invoke parameters to query_ledger will NOT fail + return call_ledger('query', fcn=fcn, args=args) + + +@retry_on_error() +def invoke_ledger(fcn, args=None, cc_pattern=None, sync=False, only_pkhash=True): + params = { + 'wait_for_event': sync, + } + + if sync: + params['wait_for_event_timeout'] = 45 + + if cc_pattern: + params['cc_pattern'] = cc_pattern + + response = call_ledger('invoke', fcn=fcn, args=args, kwargs=params) + + if only_pkhash: + return {'pkhash': response.get('key', response.get('keys'))} + else: + return response + + +@retry_on_error() +def get_object_from_ledger(pk, query): + return query_ledger(fcn=query, args={'key': pk}) + + +@retry_on_error() +def log_fail_tuple(tuple_type, tuple_key, err_msg): + err_msg = str(err_msg).replace('"', "'").replace('\\', "").replace('\\n', "")[:200] + + fail_type = 'logFailTrain' if tuple_type == 'traintuple' else 'logFailTest' + + return invoke_ledger( + fcn=fail_type, + args={ + 'key': tuple_key, + 'log': err_msg, + }, + sync=True) + + +@retry_on_error() +def log_success_tuple(tuple_type, tuple_key, res): + if tuple_type == 'traintuple': + invoke_fcn = 'logSuccessTrain' + invoke_args = { + 'key': tuple_key, + 'outModel': { + 'hash': res["end_model_file_hash"], + 'storageAddress': res["end_model_file"], + }, + 'perf': float(res["global_perf"]), + 'log': '', + } + + elif 
tuple_type == 'testtuple': + invoke_fcn = 'logSuccessTest' + invoke_args = { + 'key': tuple_key, + 'perf': float(res["global_perf"]), + 'log': '', + } + + else: + raise NotImplementedError() + + return invoke_ledger(fcn=invoke_fcn, args=invoke_args, sync=True) + + +@retry_on_error() +def log_start_tuple(tuple_type, tuple_key): + start_type = None + + if tuple_type == 'traintuple': + start_type = 'logStartTrain' + elif tuple_type == 'testtuple': + start_type = 'logStartTest' + else: + raise NotImplementedError() + + try: + invoke_ledger( + fcn=start_type, + args={'key': tuple_key}, + sync=True) + except LedgerTimeout: + pass + + +@retry_on_error() +def query_tuples(tuple_type, data_owner): + data = query_ledger( + fcn="queryFilter", + args={ + 'indexName': f'{tuple_type}~worker~status', + 'attributes': f'{data_owner},todo' + } + ) + + data = [] if data is None else data + + return data diff --git a/substrabac/substrapp/management/commands/bulkcreatedatasample.py b/backend/substrapp/management/commands/bulkcreatedatasample.py similarity index 94% rename from substrabac/substrapp/management/commands/bulkcreatedatasample.py rename to backend/substrapp/management/commands/bulkcreatedatasample.py index d1dd6d2fb..ff43f6b60 100644 --- a/substrabac/substrapp/management/commands/bulkcreatedatasample.py +++ b/backend/substrapp/management/commands/bulkcreatedatasample.py @@ -28,7 +28,8 @@ def __init__(self, msg, data): # check if not already in data sample list def check(file_or_path, pkhash, data_sample): - err_msg = 'Your data sample archives/paths contain same files leading to same pkhash, please review the content of your achives/paths. %s and %s are the same' + err_msg = 'Your data sample archives/paths contain same files leading to same pkhash, ' \ + 'please review the content of your achives/paths. %s and %s are the same' for x in data_sample: if pkhash == x['pkhash']: if 'file' in x: @@ -101,7 +102,7 @@ def bulk_create_data_sample(data): class Command(BaseCommand): - help = ''' + help = ''' # noqa Bulk create data sample paths is a list of archives or paths to directories python ./manage.py bulkcreatedatasample '{"paths": ["./data1.zip", "./data2.zip", "./train/data", "./train/data2"], "data_manager_keys": ["9a832ed6cee6acf7e33c3acffbc89cebf10ef503b690711bdee048b873daf528"], "test_only": false}' @@ -119,11 +120,11 @@ def handle(self, *args, **options): args = options['data'] try: data = json.loads(args) - except: + except Exception: try: with open(args, 'r') as f: data = json.load(f) - except: + except Exception: raise CommandError('Invalid args. 
Please review help') else: if not isinstance(data, dict): @@ -145,5 +146,6 @@ def handle(self, *args, **options): except Exception as e: self.stderr.write(str(e)) else: - msg = f'Successfully added data samples via bulk with status code {st} and data: {json.dumps(res, indent=4)}' + msg = f'Successfully added data samples via bulk with status code {st} and data: ' \ + f'{json.dumps(res, indent=4)}' self.stdout.write(self.style.SUCCESS(msg)) diff --git a/substrabac/substrapp/management/commands/createdataset.py b/backend/substrapp/management/commands/createdataset.py similarity index 89% rename from substrabac/substrapp/management/commands/createdataset.py rename to backend/substrapp/management/commands/createdataset.py index c159dcf2d..23e3edbd6 100644 --- a/substrabac/substrapp/management/commands/createdataset.py +++ b/backend/substrapp/management/commands/createdataset.py @@ -19,13 +19,13 @@ def path_leaf(path): class Command(BaseCommand): - help = ''' + help = ''' # noqa create dataset - python ./manage.py createdataset '{"data_manager": {"name": "foo", "data_opener": "./opener.py", "description": "./description.md", "type": "foo", "objective_keys": []}, "data_samples": {"paths": ["./data.zip", "./train/data"], "test_only": false}}' + python ./manage.py createdataset '{"data_manager": {"name": "foo", "data_opener": "./opener.py", "description": "./description.md", "type": "foo", "objective_keys": [], "permissions": {"public": True, "authorized_ids": []}}, "data_samples": {"paths": ["./data.zip", "./train/data"], "test_only": false}}' python ./manage.py createdataset dataset.json # datamanager.json: # objective_keys are optional - # {"data_manager": {"name": "foo", "data_opener": "./opener.py", "description": "./description.md", "type": "foo", "objective_keys": []}, "data_samples": {"paths": ["./data.zip", "./train/data"], "test_only": false}} + # {"data_manager": {"name": "foo", "data_opener": "./opener.py", "description": "./description.md", "type": "foo", "objective_keys": [], "permissions": {"public": True, "authorized_ids": []}}, "data_samples": {"paths": ["./data.zip", "./train/data"], "test_only": false}} ''' def add_arguments(self, parser): @@ -37,11 +37,11 @@ def handle(self, *args, **options): args = options['data_input'] try: data_input = json.loads(args) - except: + except Exception: try: with open(args, 'r') as f: data_input = json.load(f) - except: + except Exception: raise CommandError('Invalid args. 
Please review help') else: if not isinstance(data_input, dict): @@ -58,6 +58,8 @@ def handle(self, *args, **options): return self.stderr.write('Please provide a data_opener to your data_manager') if 'description' not in data_manager: return self.stderr.write('Please provide a description to your data_manager') + if 'permissions' not in data_manager: + return self.stderr.write('Please provide permissions to your data_manager') data_samples = data_input.get('data_samples', None) if data_samples is None: @@ -98,7 +100,7 @@ def handle(self, *args, **options): # init ledger serializer ledger_serializer = LedgerDataManagerSerializer( data={'name': data_manager['name'], - 'permissions': 'all', # forced, TODO changed when permissions are available + 'permissions': data_manager['permissions'], 'type': data_manager['type'], 'objective_keys': data_manager.get('objective_keys', []), 'instance': instance}, @@ -120,7 +122,8 @@ def handle(self, *args, **options): else: d = dict(serializer.data) d.update(res) - msg = f'Successfully added datamanager with status code {st} and result: {json.dumps(res, indent=4)}' + msg = f'Successfully added datamanager with status code {st} and result: ' \ + f'{json.dumps(res, indent=4)}' self.stdout.write(self.style.SUCCESS(msg)) # Try to add data even if datamanager creation failed diff --git a/substrabac/substrapp/management/commands/createobjective.py b/backend/substrapp/management/commands/createobjective.py similarity index 84% rename from substrabac/substrapp/management/commands/createobjective.py rename to backend/substrapp/management/commands/createobjective.py index b50f62353..4835d5433 100644 --- a/substrabac/substrapp/management/commands/createobjective.py +++ b/backend/substrapp/management/commands/createobjective.py @@ -5,11 +5,10 @@ from django.core.management.base import BaseCommand, CommandError from rest_framework import status -from substrapp.management.commands.bulkcreatedatasample import \ - bulk_create_data_sample, InvalidException +from substrapp.management.commands.bulkcreatedatasample import bulk_create_data_sample, InvalidException from substrapp.management.utils.localRequest import LocalRequest -from substrapp.serializers import DataManagerSerializer, LedgerDataManagerSerializer, \ - LedgerObjectiveSerializer, ObjectiveSerializer +from substrapp.serializers import (DataManagerSerializer, LedgerDataManagerSerializer, + LedgerObjectiveSerializer, ObjectiveSerializer) from substrapp.utils import get_hash from substrapp.views.datasample import LedgerException @@ -20,12 +19,12 @@ def path_leaf(path): class Command(BaseCommand): - help = ''' + help = ''' # noqa create objective - python ./manage.py createobjective '{"objective": {"name": "foo", "metrics_name": "accuracy", "metrics": "./metrics.py", "description": "./description.md"}, "data_manager": {"name": "foo", "data_opener": "./opener.py", "description": "./description.md", "type": "foo"}, "data_samples": {"paths": ["./data.zip", "./train/data"]}}' + python ./manage.py createobjective '{"objective": {"name": "foo", "metrics_name": "accuracy", "metrics": "./metrics.py", "description": "./description.md", "permissions": {"public": True, "authorized_ids": []}}, "data_manager": {"name": "foo", "data_opener": "./opener.py", "description": "./description.md", "type": "foo", "permissions": {"public": True, "authorized_ids": []}, "data_samples": {"paths": ["./data.zip", "./train/data"]}}' python ./manage.py createobjective objective.json # objective.json: - # {"objective": {"name": "foo", "metrics_name": 
"accuracy", "metrics": "./metrics.py", "description": "./description.md"}, "data_manager": {"name": "foo", "data_opener": "./opener.py", "description": "./description.md", "type": "foo"}, "data_samples": {"paths": ["./data.zip", "./train/data"]}} + # {"objective": {"name": "foo", "metrics_name": "accuracy", "metrics": "./metrics.py", "description": "./description.md", "permissions": {"public": True, "authorized_ids": []}, "data_manager": {"name": "foo", "data_opener": "./opener.py", "description": "./description.md", "type": "foo", "permissions": {"public": True, "authorized_ids": []}, "data_samples": {"paths": ["./data.zip", "./train/data"]}} ''' def add_arguments(self, parser): @@ -37,11 +36,11 @@ def handle(self, *args, **options): args = options['data_input'] try: data_input = json.loads(args) - except: + except Exception: try: with open(args, 'r') as f: data_input = json.load(f) - except: + except Exception: raise CommandError('Invalid args. Please review help') else: if not isinstance(data_input, dict): @@ -62,6 +61,9 @@ def handle(self, *args, **options): if 'description' not in data_manager: return self.stderr.write( 'Please provide a description to your data_manager') + if 'permissions' not in data_manager: + return self.stderr.write( + 'Please provide permissions to your data_manager') # get data and check data_samples = data_input.get('data_samples', None) @@ -85,6 +87,9 @@ def handle(self, *args, **options): if 'description' not in objective: return self.stderr.write( 'Please provide a description to your objective') + if 'permissions' not in objective: + return self.stderr.write( + 'Please provide permissions to your objective') # by default data need to be test_only data_samples['test_only'] = True @@ -120,7 +125,7 @@ def handle(self, *args, **options): # init ledger serializer ledger_serializer = LedgerDataManagerSerializer( data={'name': data_manager['name'], - 'permissions': 'all', # forced, TODO changed when permissions are available + 'permissions': data_manager['permissions'], 'type': data_manager['type'], 'instance': instance}, context={'request': LocalRequest()}) @@ -140,7 +145,8 @@ def handle(self, *args, **options): else: d = dict(serializer.data) d.update(res) - msg = f'Successfully added datamanager with status code {st} and result: {json.dumps(res, indent=4)}' + msg = f'Successfully added datamanager with status code {st} and result: ' \ + f'{json.dumps(res, indent=4)}' self.stdout.write(self.style.SUCCESS(msg)) # Try to add data even if datamanager creation failed @@ -162,7 +168,8 @@ def handle(self, *args, **options): except Exception as e: self.stderr.write(str(e)) else: - msg = f'Successfully bulk added data samples with status code {st} and result: {json.dumps(res_data, indent=4)}' + msg = f'Successfully bulk added data samples with status code {st} and result: ' \ + f'{json.dumps(res_data, indent=4)}' self.stdout.write(self.style.SUCCESS(msg)) data_sample_pkhashes = [x['pkhash'] for x in res_data] @@ -202,8 +209,7 @@ def handle(self, *args, **options): # init ledger serializer ledger_serializer = LedgerObjectiveSerializer( data={'name': objective['name'], - 'permissions': 'all', - # forced, TODO changed when permissions are available + 'permissions': objective['permissions'], 'metrics_name': objective['metrics_name'], 'test_data_sample_keys': objective.get('test_data_sample_keys', []), 'test_data_manager_key': objective.get('test_data_manager_key', ''), @@ -228,5 +234,6 @@ def handle(self, *args, **options): else: d = dict(serializer.data) d.update(res) - 
msg = f'Successfully added objective with status code {st} and result: {json.dumps(res, indent=4)}' + msg = f'Successfully added objective with status code {st} and result: ' \ + f'{json.dumps(res, indent=4)}' self.stdout.write(self.style.SUCCESS(msg)) diff --git a/substrabac/substrapp/tests/tests_bulkcreatedatasample.py b/backend/substrapp/management/tests/tests_bulkcreatedatasample.py similarity index 84% rename from substrabac/substrapp/tests/tests_bulkcreatedatasample.py rename to backend/substrapp/management/tests/tests_bulkcreatedatasample.py index 9b0f49715..6048bbddc 100644 --- a/substrabac/substrapp/tests/tests_bulkcreatedatasample.py +++ b/backend/substrapp/management/tests/tests_bulkcreatedatasample.py @@ -25,7 +25,6 @@ @override_settings(MEDIA_ROOT=MEDIA_ROOT) -@override_settings(SITE_HOST='localhost') @override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) class BulkCreateDataSampleTestCase(TestCase): @@ -50,26 +49,25 @@ def setUp(self): self.data_sample_file, self.data_sample_file_filename = get_sample_zip_data_sample() def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) def test_bulkcreatedatasample(self): dir_path = os.path.dirname(os.path.realpath(__file__)) - data_path1 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) - data_path2 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/datasample0/0024899.zip')) + data_path1 = os.path.normpath( + os.path.join(dir_path, '../../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) + data_path2 = os.path.normpath( + os.path.join(dir_path, '../../../fixtures/chunantes/datasamples/datasample0/0024899.zip')) - data_manager_keys = [get_hash(os.path.join(dir_path, '../../fixtures/owkin/datamanagers/datamanager0/opener.py'))] + data_manager_keys = [ + get_hash(os.path.join(dir_path, '../../../fixtures/owkin/datamanagers/datamanager0/opener.py'))] data = {'paths': [data_path1, data_path2], 'data_manager_keys': data_manager_keys, 'test_only': False} + # dir hash pkhash1 = '24fb12ff87485f6b0bc5349e5bf7f36ccca4eb1353395417fdae7d8d787f178c' pkhash2 = '30f6c797e277451b0a08da7119ed86fb2986fa7fab2258bf3edbd9f1752ed553' @@ -106,7 +104,8 @@ def test_bulkcreatedatasample(self): } ] data = json.dumps(out_data, indent=4) - wanted_output = f'Successfully added data samples via bulk with status code {status.HTTP_201_CREATED} and data: {data}' + wanted_output = f'Successfully added data samples via bulk with status code ' \ + f'{status.HTTP_201_CREATED} and data: {data}' self.assertEqual(wanted_output, output) finally: sys.stdout = saved_stdout @@ -115,16 +114,17 @@ def test_bulkcreatedatasample_path(self): dir_path = os.path.dirname(os.path.realpath(__file__)) - data_path1 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/train/0024308')) + data_path1 = os.path.normpath( + os.path.join(dir_path, '../../../fixtures/chunantes/datasamples/train/0024308')) - data_manager_keys = [get_hash(os.path.join(dir_path, '../../fixtures/owkin/datamanagers/datamanager0/opener.py'))] + data_manager_keys = [ + get_hash(os.path.join(dir_path, '../../../fixtures/owkin/datamanagers/datamanager0/opener.py'))] data = {'paths': [data_path1], 'data_manager_keys': data_manager_keys, 'test_only': False} - pkhash1 = 'e3644123451975be20909fcfd9c664a0573d9bfe04c5021625412d78c3536f1c' + pkhash1 = get_hash(data_path1) with patch.object(DataManager.objects, 
'filter') as mdatamanager, \ patch.object(LedgerDataSampleSerializer, 'create') as mcreate: @@ -160,7 +160,8 @@ def test_bulkcreatedatasample_path(self): }, ] data = json.dumps(out_data, indent=4) - wanted_output = f'Successfully added data samples via bulk with status code {status.HTTP_201_CREATED} and data: {data}' + wanted_output = f'Successfully added data samples via bulk with status code ' \ + f'{status.HTTP_201_CREATED} and data: {data}' self.assertEqual(wanted_output, output) finally: sys.stdout = saved_stdout @@ -170,15 +171,16 @@ def test_bulkcreatedatasample_original_path(self): dir_path = os.path.dirname(os.path.realpath(__file__)) data_path1 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/train/0024308')) + '../../../fixtures/chunantes/datasamples/train/0024308')) - data_manager_keys = [get_hash(os.path.join(dir_path, '../../fixtures/owkin/datamanagers/datamanager0/opener.py'))] + data_manager_keys = [ + get_hash(os.path.join(dir_path, '../../../fixtures/owkin/datamanagers/datamanager0/opener.py'))] data = {'paths': [data_path1], 'data_manager_keys': data_manager_keys, 'test_only': False} - pkhash1 = 'e3644123451975be20909fcfd9c664a0573d9bfe04c5021625412d78c3536f1c' + pkhash1 = get_hash(data_path1) with patch.object(DataManager.objects, 'filter') as mdatamanager, \ patch.object(LedgerDataSampleSerializer, 'create') as mcreate: @@ -213,7 +215,8 @@ def test_bulkcreatedatasample_original_path(self): }, ] data = json.dumps(out_data, indent=4) - wanted_output = f'Successfully added data samples via bulk with status code {status.HTTP_201_CREATED} and data: {data}' + wanted_output = f'Successfully added data samples via bulk with status code ' \ + f'{status.HTTP_201_CREATED} and data: {data}' self.assertEqual(wanted_output, output) finally: sys.stdout = saved_stdout @@ -223,16 +226,18 @@ def test_bulkcreatedatasample_path_and_files(self): dir_path = os.path.dirname(os.path.realpath(__file__)) data_path1 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/train/0024308')) + '../../../fixtures/chunantes/datasamples/train/0024308')) data_path2 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/datasample0/0024899.zip')) + '../../../fixtures/chunantes/datasamples/datasample0/0024899.zip')) - data_manager_keys = [get_hash(os.path.join(dir_path, '../../fixtures/owkin/datamanagers/datamanager0/opener.py'))] + data_manager_keys = [ + get_hash(os.path.join(dir_path, '../../../fixtures/owkin/datamanagers/datamanager0/opener.py'))] data = {'paths': [data_path1, data_path2], 'data_manager_keys': data_manager_keys, 'test_only': False} + # dir hash pkhash1 = 'e3644123451975be20909fcfd9c664a0573d9bfe04c5021625412d78c3536f1c' pkhash2 = '30f6c797e277451b0a08da7119ed86fb2986fa7fab2258bf3edbd9f1752ed553' @@ -275,7 +280,8 @@ def test_bulkcreatedatasample_path_and_files(self): }, ] data = json.dumps(out_data, indent=4) - wanted_output = f'Successfully added data samples via bulk with status code {status.HTTP_201_CREATED} and data: {data}' + wanted_output = f'Successfully added data samples via bulk with status code ' \ + f'{status.HTTP_201_CREATED} and data: {data}' self.assertEqual(wanted_output, output) finally: sys.stdout = saved_stdout @@ -285,9 +291,10 @@ def test_bulkcreatedatasample_same_on_file(self): dir_path = os.path.dirname(os.path.realpath(__file__)) data_path1 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) + 
'../../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) - data_manager_keys = [get_hash(os.path.join(dir_path, '../../fixtures/owkin/datamanagers/datamanager0/opener.py'))] + data_manager_keys = [ + get_hash(os.path.join(dir_path, '../../../fixtures/owkin/datamanagers/datamanager0/opener.py'))] data = {'paths': [data_path1, data_path1], 'data_manager_keys': data_manager_keys, @@ -315,7 +322,9 @@ def test_bulkcreatedatasample_same_on_file(self): output = err.getvalue().strip() - wanted_output = f'Your data sample archives/paths contain same files leading to same pkhash, please review the content of your achives/paths. {data_path1} and 0024700.zip are the same' + wanted_output = f'Your data sample archives/paths contain same files leading to same pkhash, ' \ + f'please review the content of your achives/paths. ' \ + f'{data_path1} and 0024700.zip are the same' self.assertEqual(wanted_output, output) finally: sys.stdout = saved_stdout @@ -325,9 +334,10 @@ def test_bulkcreatedatasample_same_on_path(self): dir_path = os.path.dirname(os.path.realpath(__file__)) data_path1 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/train/0024308')) + '../../../fixtures/chunantes/datasamples/train/0024308')) - data_manager_keys = [get_hash(os.path.join(dir_path, '../../fixtures/owkin/datamanagers/datamanager0/opener.py'))] + data_manager_keys = [ + get_hash(os.path.join(dir_path, '../../../fixtures/owkin/datamanagers/datamanager0/opener.py'))] data = {'paths': [data_path1, data_path1], 'data_manager_keys': data_manager_keys, @@ -355,7 +365,9 @@ def test_bulkcreatedatasample_same_on_path(self): output = err.getvalue().strip() - wanted_output = f'Your data sample archives/paths contain same files leading to same pkhash, please review the content of your achives/paths. {data_path1} and {data_path1} are the same' + wanted_output = f'Your data sample archives/paths contain same files leading to same pkhash, ' \ + f'please review the content of your achives/paths. ' \ + f'{data_path1} and {data_path1} are the same' self.assertEqual(wanted_output, output) finally: sys.stdout = saved_stdout @@ -416,7 +428,8 @@ def test_bulkcreatedatasample_invalid_datamanager(self): output = err.getvalue().strip() - wanted_output = "One or more datamanager keys provided do not exist in local substrabac database. Please create them before. DataManager keys: ['bar']" + wanted_output = "One or more datamanager keys provided do not exist in local database. "\ + "Please create them before. DataManager keys: ['bar']" self.assertEqual(wanted_output, output) @@ -439,7 +452,8 @@ def test_bulkcreatedatasample_not_array_datamanager(self): def test_bulkcreatedatasample_datamanager_do_not_exist(self): dir_path = os.path.dirname(os.path.realpath(__file__)) - data_manager_keys = [get_hash(os.path.join(dir_path, '../../fixtures/owkin/datamanagers/datamanager0/opener.py'))] + data_manager_keys = [ + get_hash(os.path.join(dir_path, '../../../fixtures/owkin/datamanagers/datamanager0/opener.py'))] data = {'files': ['./foo'], 'data_manager_keys': data_manager_keys, @@ -451,7 +465,8 @@ def test_bulkcreatedatasample_datamanager_do_not_exist(self): output = err.getvalue().strip() - wanted_output = f"One or more datamanager keys provided do not exist in local substrabac database. Please create them before. DataManager keys: {data_manager_keys}" + wanted_output = f"One or more datamanager keys provided do not exist in local database. " \ + f"Please create them before. 
DataManager keys: {data_manager_keys}" self.assertEqual(wanted_output, output) @@ -472,8 +487,8 @@ def test_bulkcreatedatasample_invalid_file(self): def test_bulkcreatedatasample_invalid_serializer(self): dir_path = os.path.dirname(os.path.realpath(__file__)) - data_path1 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) + data_path1 = os.path.normpath( + os.path.join(dir_path, '../../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) data = {'paths': [data_path1], 'data_manager_keys': [self.datamanager.pk], @@ -485,7 +500,7 @@ def test_bulkcreatedatasample_invalid_serializer(self): with patch.object(zipfile, 'is_zipfile') as mis_zipfile, \ patch.object(os.path, 'exists') as mexists, \ patch('substrapp.management.commands.bulkcreatedatasample.open', - mock_open(read_data=self.data_sample_file.read())) as mopen, \ + mock_open(read_data=self.data_sample_file.read())), \ patch( 'substrapp.management.commands.bulkcreatedatasample.DataSampleSerializer', spec=True) as mDataSampleSerializer: @@ -507,7 +522,7 @@ def test_bulkcreatedatasample_invalid_serializer(self): def test_bulkcreatedatasample_408(self): dir_path = os.path.dirname(os.path.realpath(__file__)) data_path1 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) + '../../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) data = {'paths': [data_path1], 'data_manager_keys': [self.datamanager.pk], @@ -519,7 +534,7 @@ def test_bulkcreatedatasample_408(self): with patch.object(zipfile, 'is_zipfile') as mis_zipfile, \ patch.object(os.path, 'exists') as mexists, \ patch('substrapp.management.commands.bulkcreatedatasample.open', - mock_open(read_data=self.data_sample_file.read())) as mopen, \ + mock_open(read_data=self.data_sample_file.read())), \ patch( 'substrapp.management.commands.bulkcreatedatasample.DataSampleSerializer', spec=True) as mDataSampleSerializer, \ @@ -546,7 +561,7 @@ def test_bulkcreatedatasample_408(self): def test_bulkcreatedatasample_ledger_400(self): dir_path = os.path.dirname(os.path.realpath(__file__)) data_path1 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) + '../../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) data = {'paths': [data_path1], 'data_manager_keys': [self.datamanager.pk], @@ -558,7 +573,7 @@ def test_bulkcreatedatasample_ledger_400(self): with patch.object(zipfile, 'is_zipfile') as mis_zipfile, \ patch.object(os.path, 'exists') as mexists, \ patch('substrapp.management.commands.bulkcreatedatasample.open', - mock_open(read_data=self.data_sample_file.read())) as mopen, \ + mock_open(read_data=self.data_sample_file.read())), \ patch( 'substrapp.management.commands.bulkcreatedatasample.DataSampleSerializer', spec=True) as mDataSampleSerializer, \ @@ -585,7 +600,7 @@ def test_bulkcreatedatasample_ledger_400(self): def test_bulkcreatedatasample_400(self): dir_path = os.path.dirname(os.path.realpath(__file__)) data_path1 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) + '../../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) data = {'paths': [data_path1], 'data_manager_keys': [self.datamanager.pk], @@ -597,7 +612,7 @@ def test_bulkcreatedatasample_400(self): with patch.object(zipfile, 'is_zipfile') as mis_zipfile, \ patch.object(os.path, 'exists') as mexists, \ patch('substrapp.management.commands.bulkcreatedatasample.open', - 
mock_open(read_data=self.data_sample_file.read())) as mopen, \ + mock_open(read_data=self.data_sample_file.read())), \ patch( 'substrapp.management.commands.bulkcreatedatasample.DataSampleSerializer', spec=True) as mDataSampleSerializer, \ diff --git a/substrabac/substrapp/tests/tests_createdatamanager.py b/backend/substrapp/management/tests/tests_createdatamanager.py similarity index 66% rename from substrabac/substrapp/tests/tests_createdatamanager.py rename to backend/substrapp/management/tests/tests_createdatamanager.py index 5592b9131..d3575ad0c 100644 --- a/substrabac/substrapp/tests/tests_createdatamanager.py +++ b/backend/substrapp/management/tests/tests_createdatamanager.py @@ -17,7 +17,6 @@ @override_settings(MEDIA_ROOT=MEDIA_ROOT) -@override_settings(SITE_HOST='localhost') @override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) class CreateDataManagerTestCase(TestCase): @@ -26,24 +25,21 @@ def setUp(self): os.makedirs(MEDIA_ROOT) def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) def test_createdatamanager(self): dir_path = os.path.dirname(os.path.realpath(__file__)) - data_path1 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) - data_path2 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/datasample0/0024899.zip')) + data_path1 = os.path.normpath(os.path.join( + dir_path, '../../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) + data_path2 = os.path.normpath(os.path.join( + dir_path, '../../../fixtures/chunantes/datasamples/datasample0/0024899.zip')) - datamanager_opener_path = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datamanagers/datamanager0/opener.py')) - datamanager_description_path = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datamanagers/datamanager0/description.md')) + datamanager_opener_path = os.path.normpath(os.path.join( + dir_path, '../../../fixtures/chunantes/datamanagers/datamanager0/opener.py')) + datamanager_description_path = os.path.normpath(os.path.join( + dir_path, '../../../fixtures/chunantes/datamanagers/datamanager0/description.md')) data = { 'data_manager': { @@ -68,15 +64,13 @@ def test_createdatamanager(self): 'substrapp.views.datasample.DataSampleViewSet.check_datamanagers') as mcheck_datamanagers: mdatamanagercreate.return_value = ({ - 'pkhash': datamanager_pk, - 'validated': True - }, - status.HTTP_201_CREATED) + 'pkhash': datamanager_pk, + 'validated': True + }, status.HTTP_201_CREATED) mdatacreate.return_value = ({ - 'pkhash': [pkhash1, pkhash2], - 'validated': True - }, - status.HTTP_201_CREATED) + 'pkhash': [pkhash1, pkhash2], + 'validated': True + }, status.HTTP_201_CREATED) mcheck_datamanagers.return_value = True saved_stdout = sys.stdout @@ -89,9 +83,9 @@ def test_createdatamanager(self): output = out.getvalue().strip() datamanager_out = { - "pkhash": datamanager_pk, - "validated": True - } + "pkhash": datamanager_pk, + "validated": True + } data_out = [ { @@ -108,9 +102,12 @@ def test_createdatamanager(self): datamanager = json.dumps(datamanager_out, indent=4) data = json.dumps(data_out, indent=4) - datamanager_wanted_output = f'Successfully added datamanager with status code {status.HTTP_201_CREATED} and result: {datamanager}' - data_wanted_output = f'Successfully bulk added data samples with status code {status.HTTP_201_CREATED} and result: {data}' - self.assertEqual(output, 
f'{datamanager_wanted_output}\nWill add data to this datamanager now\n{data_wanted_output}') + datamanager_wanted_output = f'Successfully added datamanager with status code ' \ + f'{status.HTTP_201_CREATED} and result: {datamanager}' + data_wanted_output = f'Successfully bulk added data samples with status code ' \ + f'{status.HTTP_201_CREATED} and result: {data}' + self.assertEqual(output, f'{datamanager_wanted_output}\nWill add data to this datamanager now' + f'\n{data_wanted_output}') finally: sys.stdout = saved_stdout @@ -118,15 +115,15 @@ def test_createdatamanager_ko_409(self): dir_path = os.path.dirname(os.path.realpath(__file__)) - data_path1 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) - data_path2 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/datasample0/0024899.zip')) + data_path1 = os.path.normpath(os.path.join( + dir_path, '../../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) + data_path2 = os.path.normpath(os.path.join( + dir_path, '../../../fixtures/chunantes/datasamples/datasample0/0024899.zip')) - datamanager_opener_path = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datamanagers/datamanager0/opener.py')) - datamanager_description_path = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datamanagers/datamanager0/description.md')) + datamanager_opener_path = os.path.normpath(os.path.join( + dir_path, '../../../fixtures/chunantes/datamanagers/datamanager0/opener.py')) + datamanager_description_path = os.path.normpath(os.path.join( + dir_path, '../../../fixtures/chunantes/datamanagers/datamanager0/description.md')) data = { 'data_manager': { @@ -150,14 +147,12 @@ def test_createdatamanager_ko_409(self): 'substrapp.views.datasample.DataSampleViewSet.check_datamanagers') as mcheck_datamanagers: mdatamanagercreate.return_value = ({ - 'message': 'datamanager already exists', - }, - status.HTTP_409_CONFLICT) + 'message': 'datamanager already exists', + }, status.HTTP_409_CONFLICT) mdatacreate.return_value = ({ - 'pkhash': [pkhash1, pkhash2], - 'validated': True - }, - status.HTTP_201_CREATED) + 'pkhash': [pkhash1, pkhash2], + 'validated': True + }, status.HTTP_201_CREATED) mcheck_datamanagers.return_value = True saved_stdout = sys.stdout @@ -173,8 +168,8 @@ def test_createdatamanager_ko_409(self): err_output = err.getvalue().strip() datamanager_out = { - "message": 'datamanager already exists', - } + "message": 'datamanager already exists', + } data_out = [ { @@ -191,7 +186,8 @@ def test_createdatamanager_ko_409(self): datamanager = json.dumps(datamanager_out, indent=2) data = json.dumps(data_out, indent=4) - data_wanted_output = f'Successfully bulk added data samples with status code {status.HTTP_201_CREATED} and result: {data}' + data_wanted_output = f'Successfully bulk added data samples with status code ' \ + f'{status.HTTP_201_CREATED} and result: {data}' self.assertEqual(output, f'Will add data to this datamanager now\n{data_wanted_output}') self.assertEqual(err_output, datamanager) finally: diff --git a/substrabac/substrapp/tests/tests_createobjective.py b/backend/substrapp/management/tests/tests_createobjective.py similarity index 60% rename from substrabac/substrapp/tests/tests_createobjective.py rename to backend/substrapp/management/tests/tests_createobjective.py index 1005011f5..4e1bf8b62 100644 --- a/substrabac/substrapp/tests/tests_createobjective.py +++ b/backend/substrapp/management/tests/tests_createobjective.py 
@@ -16,7 +16,6 @@ @override_settings(MEDIA_ROOT=MEDIA_ROOT) -@override_settings(SITE_HOST='localhost') @override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) class CreateObjectiveTestCase(TestCase): @@ -26,28 +25,25 @@ def setUp(self): self.maxDiff = None def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) def test_createobjective(self): dir_path = os.path.dirname(os.path.realpath(__file__)) - data_path1 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) - data_path2 = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datasamples/datasample0/0024899.zip')) + data_path1 = os.path.normpath(os.path.join( + dir_path, '../../../fixtures/chunantes/datasamples/datasample1/0024700.zip')) + data_path2 = os.path.normpath(os.path.join( + dir_path, '../../../fixtures/chunantes/datasamples/datasample0/0024899.zip')) - datamanager_opener_path = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datamanagers/datamanager0/opener.py')) - datamanager_description_path = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/datamanagers/datamanager0/description.md')) - objective_metrics_path = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/objectives/objective0/metrics.py')) - objective_description_path = os.path.normpath(os.path.join(dir_path, - '../../fixtures/chunantes/objectives/objective0/description.md')) + datamanager_opener_path = os.path.normpath(os.path.join( + dir_path, '../../../fixtures/chunantes/datamanagers/datamanager0/opener.py')) + datamanager_description_path = os.path.normpath(os.path.join( + dir_path, '../../../fixtures/chunantes/datamanagers/datamanager0/description.md')) + objective_metrics_path = os.path.normpath(os.path.join( + dir_path, '../../../fixtures/chunantes/objectives/objective0/metrics.py')) + objective_description_path = os.path.normpath(os.path.join( + dir_path, '../../../fixtures/chunantes/objectives/objective0/description.md')) data = { 'objective': { @@ -69,7 +65,7 @@ def test_createobjective(self): } objective_pk = 'd5002e1cd50bd5de5341df8a7b7d11b6437154b3b08f531c9b8f93889855c66f' - datamanager_pk = '615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7' + datamanager_pk = '8dd01465003a9b1e01c99c904d86aa518b3a5dd9dc8d40fe7d075c726ac073ca' pkhash1 = '24fb12ff87485f6b0bc5349e5bf7f36ccca4eb1353395417fdae7d8d787f178c' pkhash2 = '30f6c797e277451b0a08da7119ed86fb2986fa7fab2258bf3edbd9f1752ed553' @@ -79,20 +75,17 @@ def test_createobjective(self): patch('substrapp.views.datasample.DataSampleViewSet.check_datamanagers') as mcheck_datamanagers: mobjectivecreate.return_value = ({ - 'pkhash': objective_pk, - 'validated': True - }, - status.HTTP_201_CREATED) + 'pkhash': objective_pk, + 'validated': True + }, status.HTTP_201_CREATED) mdatamanagercreate.return_value = ({ - 'pkhash': datamanager_pk, - 'validated': True - }, - status.HTTP_201_CREATED) + 'pkhash': datamanager_pk, + 'validated': True + }, status.HTTP_201_CREATED) mdatacreate.return_value = ({ - 'pkhash': [pkhash1, pkhash2], - 'validated': True - }, - status.HTTP_201_CREATED) + 'pkhash': [pkhash1, pkhash2], + 'validated': True + }, status.HTTP_201_CREATED) mcheck_datamanagers.return_value = True @@ -111,9 +104,9 @@ def test_createobjective(self): } datamanager_out = { - "pkhash": datamanager_pk, - "validated": True - } + "pkhash": datamanager_pk, + "validated": True + } data_out = [ { 
@@ -131,9 +124,14 @@ def test_createobjective(self): datamanager = json.dumps(datamanager_out, indent=4) data = json.dumps(data_out, indent=4) objective = json.dumps(objective_out, indent=4) - datamanager_wanted_output = f'Successfully added datamanager with status code {status.HTTP_201_CREATED} and result: {datamanager}' - data_wanted_output = f'Successfully bulk added data samples with status code {status.HTTP_201_CREATED} and result: {data}' - objective_wanted_output = f'Successfully added objective with status code {status.HTTP_201_CREATED} and result: {objective}' - self.assertEqual(output, f'{datamanager_wanted_output}\nWill add data samples to this datamanager now\n{data_wanted_output}\nWill add objective to this datamanager now\n{objective_wanted_output}') + datamanager_wanted_output = f'Successfully added datamanager with status code ' \ + f'{status.HTTP_201_CREATED} and result: {datamanager}' + data_wanted_output = f'Successfully bulk added data samples with status code ' \ + f'{status.HTTP_201_CREATED} and result: {data}' + objective_wanted_output = f'Successfully added objective with status code ' \ + f'{status.HTTP_201_CREATED} and result: {objective}' + self.assertEqual(output, f'{datamanager_wanted_output}\nWill add data samples to this datamanager now' + f'\n{data_wanted_output}\nWill add objective to this datamanager now' + f'\n{objective_wanted_output}') finally: sys.stdout = saved_stdout diff --git a/substrabac/substrapp/serializers/ledger/traintuple/__init__.py b/backend/substrapp/management/utils/__init__.py similarity index 100% rename from substrabac/substrapp/serializers/ledger/traintuple/__init__.py rename to backend/substrapp/management/utils/__init__.py diff --git a/backend/substrapp/management/utils/localRequest.py b/backend/substrapp/management/utils/localRequest.py new file mode 100644 index 000000000..eabc23df4 --- /dev/null +++ b/backend/substrapp/management/utils/localRequest.py @@ -0,0 +1,12 @@ +from django.conf import settings +from urllib.parse import urlparse + + +class LocalRequest(object): + + def is_secure(self): + return not getattr(settings, 'DEBUG') + + def get_host(self): + # remove protocol (http/https) from default domain + return urlparse(getattr(settings, 'DEFAULT_DOMAIN')).netloc diff --git a/substrabac/substrapp/migrations/0001_initial.py b/backend/substrapp/migrations/0001_initial.py similarity index 100% rename from substrabac/substrapp/migrations/0001_initial.py rename to backend/substrapp/migrations/0001_initial.py diff --git a/substrabac/substrapp/signals/__init__.py b/backend/substrapp/migrations/__init__.py similarity index 100% rename from substrabac/substrapp/signals/__init__.py rename to backend/substrapp/migrations/__init__.py diff --git a/substrabac/substrapp/models/__init__.py b/backend/substrapp/models/__init__.py similarity index 100% rename from substrabac/substrapp/models/__init__.py rename to backend/substrapp/models/__init__.py diff --git a/substrabac/substrapp/models/algo.py b/backend/substrapp/models/algo.py similarity index 100% rename from substrabac/substrapp/models/algo.py rename to backend/substrapp/models/algo.py diff --git a/substrabac/substrapp/models/datamanager.py b/backend/substrapp/models/datamanager.py similarity index 100% rename from substrabac/substrapp/models/datamanager.py rename to backend/substrapp/models/datamanager.py diff --git a/substrabac/substrapp/models/datasample.py b/backend/substrapp/models/datasample.py similarity index 87% rename from substrabac/substrapp/models/datasample.py rename to 
backend/substrapp/models/datasample.py index 409deab06..41d794b14 100644 --- a/substrabac/substrapp/models/datasample.py +++ b/backend/substrapp/models/datasample.py @@ -3,10 +3,6 @@ from substrapp.utils import get_hash -def upload_to(instance, filename): - return 'datasamples/{0}/{1}'.format(instance.pk, filename) - - class DataSample(models.Model): """Storage Data table""" pkhash = models.CharField(primary_key=True, max_length=64, blank=True) diff --git a/substrabac/substrapp/models/model.py b/backend/substrapp/models/model.py similarity index 100% rename from substrabac/substrapp/models/model.py rename to backend/substrapp/models/model.py diff --git a/substrabac/substrapp/models/objective.py b/backend/substrapp/models/objective.py similarity index 84% rename from substrabac/substrapp/models/objective.py rename to backend/substrapp/models/objective.py index fd4685539..69c6dcc5d 100644 --- a/substrabac/substrapp/models/objective.py +++ b/backend/substrapp/models/objective.py @@ -12,8 +12,8 @@ class Objective(TimeStamped): """Storage Objective table""" pkhash = models.CharField(primary_key=True, max_length=64, blank=True) validated = models.BooleanField(default=False, blank=True) - description = models.FileField(upload_to=upload_to, max_length=500, blank=True, null=True) # path max length to 500 instead of default 100 - metrics = models.FileField(upload_to=upload_to, max_length=500, blank=True, null=True) # path max length to 500 instead of default 100 + description = models.FileField(upload_to=upload_to, max_length=500, blank=True, null=True) + metrics = models.FileField(upload_to=upload_to, max_length=500, blank=True, null=True) def save(self, *args, **kwargs): """Use hash of description file as primary key""" diff --git a/substrabac/substrapp/serializers/__init__.py b/backend/substrapp/serializers/__init__.py similarity index 88% rename from substrabac/substrapp/serializers/__init__.py rename to backend/substrapp/serializers/__init__.py index e0f4f7577..015186c69 100644 --- a/substrabac/substrapp/serializers/__init__.py +++ b/backend/substrapp/serializers/__init__.py @@ -12,4 +12,4 @@ 'LedgerObjectiveSerializer', 'LedgerModelSerializer', 'LedgerDataSampleSerializer', 'LedgerAlgoSerializer', 'LedgerTrainTupleSerializer', 'LedgerTestTupleSerializer', - 'LedgerDataManagerSerializer'] + 'LedgerDataManagerSerializer', 'LedgerComputePlanSerializer'] diff --git a/backend/substrapp/serializers/algo.py b/backend/substrapp/serializers/algo.py new file mode 100644 index 000000000..2b0d15e98 --- /dev/null +++ b/backend/substrapp/serializers/algo.py @@ -0,0 +1,14 @@ +from rest_framework import serializers + +from libs.serializers import DynamicFieldsModelSerializer +from substrapp.models import Algo + +from substrapp.serializers.utils import FileValidator + + +class AlgoSerializer(DynamicFieldsModelSerializer): + file = serializers.FileField(validators=[FileValidator()]) + + class Meta: + model = Algo + fields = '__all__' diff --git a/substrabac/substrapp/serializers/datamanager.py b/backend/substrapp/serializers/datamanager.py similarity index 100% rename from substrabac/substrapp/serializers/datamanager.py rename to backend/substrapp/serializers/datamanager.py diff --git a/substrabac/substrapp/serializers/datasample.py b/backend/substrapp/serializers/datasample.py similarity index 97% rename from substrabac/substrapp/serializers/datasample.py rename to backend/substrapp/serializers/datasample.py index 5a99d84e1..1f3b974d6 100644 --- a/substrabac/substrapp/serializers/datasample.py +++ 
b/backend/substrapp/serializers/datasample.py @@ -4,7 +4,6 @@ from django.core.exceptions import ValidationError from django.core.files import File -from django.core.files.uploadedfile import InMemoryUploadedFile from rest_framework import serializers from rest_framework.serializers import raise_errors_on_nested_writes from rest_framework.utils import model_meta @@ -26,7 +25,7 @@ def __call__(self, data): try: data.file.seek(0) - except: + except Exception: raise ValidationError(self.error_messages['open']) else: try: diff --git a/substrabac/substrapp/serializers/ledger/__init__.py b/backend/substrapp/serializers/ledger/__init__.py similarity index 75% rename from substrabac/substrapp/serializers/ledger/__init__.py rename to backend/substrapp/serializers/ledger/__init__.py index 929e5c3af..327c989ad 100644 --- a/substrabac/substrapp/serializers/ledger/__init__.py +++ b/backend/substrapp/serializers/ledger/__init__.py @@ -1,14 +1,15 @@ # encoding: utf-8 from .objective.serializer import LedgerObjectiveSerializer -from .model import LedgerModelSerializer +from .model.serializer import LedgerModelSerializer from .datasample.serializer import LedgerDataSampleSerializer from .algo.serializer import LedgerAlgoSerializer from .traintuple.serializer import LedgerTrainTupleSerializer from .testtuple.serializer import LedgerTestTupleSerializer from .datamanager.serializer import LedgerDataManagerSerializer +from .computeplan.serializer import LedgerComputePlanSerializer __all__ = ['LedgerObjectiveSerializer', 'LedgerModelSerializer', 'LedgerDataSampleSerializer', 'LedgerAlgoSerializer', 'LedgerTrainTupleSerializer', 'LedgerTestTupleSerializer', - 'LedgerDataManagerSerializer'] + 'LedgerDataManagerSerializer', 'LedgerComputePlanSerializer'] diff --git a/substrabac/substrapp/signals/algo/__init__.py b/backend/substrapp/serializers/ledger/algo/__init__.py similarity index 100% rename from substrabac/substrapp/signals/algo/__init__.py rename to backend/substrapp/serializers/ledger/algo/__init__.py diff --git a/substrabac/substrapp/serializers/ledger/algo/serializer.py b/backend/substrapp/serializers/ledger/algo/serializer.py similarity index 63% rename from substrabac/substrapp/serializers/ledger/algo/serializer.py rename to backend/substrapp/serializers/ledger/algo/serializer.py index e6d53ed1c..78365f26e 100644 --- a/substrabac/substrapp/serializers/ledger/algo/serializer.py +++ b/backend/substrapp/serializers/ledger/algo/serializer.py @@ -1,16 +1,17 @@ -from rest_framework import serializers, status +from rest_framework import serializers from django.conf import settings from rest_framework.reverse import reverse from substrapp.utils import get_hash +from substrapp.serializers.ledger.utils import PermissionsSerializer from .util import createLedgerAlgo from .tasks import createLedgerAlgoAsync class LedgerAlgoSerializer(serializers.Serializer): name = serializers.CharField(min_length=1, max_length=100) - permissions = serializers.CharField(min_length=1, max_length=60) + permissions = PermissionsSerializer() def create(self, validated_data): instance = self.initial_data.get('instance') @@ -18,27 +19,34 @@ def create(self, validated_data): permissions = validated_data.get('permissions') # TODO, create a datamigration with new Site domain name when we will know the name of the final website - # current_site = Site.objects.get_current() + host = '' + protocol = 'http://' request = self.context.get('request', None) - protocol = 'https://' if request.is_secure() else 'http://' - host = '' if request is None 
else request.get_host() - args = '"%(name)s", "%(algoHash)s", "%(storageAddress)s", "%(descriptionHash)s", "%(descriptionStorageAddress)s", "%(permissions)s"' % { + if request: + protocol = 'https://' if request.is_secure() else 'http://' + host = request.get_host() + + args = { 'name': name, - 'algoHash': get_hash(instance.file), + 'hash': get_hash(instance.file), 'storageAddress': protocol + host + reverse('substrapp:algo-file', args=[instance.pk]), 'descriptionHash': get_hash(instance.description), 'descriptionStorageAddress': protocol + host + reverse('substrapp:algo-description', args=[instance.pk]), - 'permissions': permissions + 'permissions': {'process': { + 'public': permissions.get('public'), + 'authorizedIDs': permissions.get('authorized_ids'), + }} } if getattr(settings, 'LEDGER_SYNC_ENABLED'): - return createLedgerAlgo(args, instance.pkhash, sync=True) + data = createLedgerAlgo(args, instance.pkhash, sync=True) else: # use a celery task, as we are in an http request transaction createLedgerAlgoAsync.delay(args, instance.pkhash) data = { - 'message': 'Algo added in local db waiting for validation. The substra network has been notified for adding this Algo' + 'message': 'Algo added in local db waiting for validation. ' + 'The substra network has been notified for adding this Algo' } - st = status.HTTP_202_ACCEPTED - return data, st + + return data diff --git a/substrabac/substrapp/serializers/ledger/algo/tasks.py b/backend/substrapp/serializers/ledger/algo/tasks.py similarity index 100% rename from substrabac/substrapp/serializers/ledger/algo/tasks.py rename to backend/substrapp/serializers/ledger/algo/tasks.py diff --git a/backend/substrapp/serializers/ledger/algo/util.py b/backend/substrapp/serializers/ledger/algo/util.py new file mode 100644 index 000000000..5efce85cd --- /dev/null +++ b/backend/substrapp/serializers/ledger/algo/util.py @@ -0,0 +1,13 @@ +from __future__ import absolute_import, unicode_literals + +from substrapp.models import Algo +from substrapp.serializers.ledger.utils import create_ledger_asset + + +def createLedgerAlgo(args, pkhash, sync=False): + return create_ledger_asset( + model=Algo, + fcn='registerAlgo', + args=args, + pkhash=pkhash, + sync=sync) diff --git a/substrabac/substrapp/signals/datamanager/__init__.py b/backend/substrapp/serializers/ledger/computeplan/__init__.py similarity index 100% rename from substrabac/substrapp/signals/datamanager/__init__.py rename to backend/substrapp/serializers/ledger/computeplan/__init__.py diff --git a/backend/substrapp/serializers/ledger/computeplan/serializer.py b/backend/substrapp/serializers/ledger/computeplan/serializer.py new file mode 100644 index 000000000..4f40975fe --- /dev/null +++ b/backend/substrapp/serializers/ledger/computeplan/serializer.py @@ -0,0 +1,89 @@ +from rest_framework import serializers + +from django.conf import settings + +from .util import createLedgerComputePlan +from .tasks import createLedgerComputePlanAsync + + +class ComputePlanTraintupleSerializer(serializers.Serializer): + data_manager_key = serializers.CharField(min_length=64, max_length=64) + train_data_sample_keys = serializers.ListField( + child=serializers.CharField(min_length=64, max_length=64), + min_length=1) + traintuple_id = serializers.CharField(min_length=1, max_length=64) + in_models_ids = serializers.ListField( + child=serializers.CharField(min_length=1, max_length=64), + min_length=0, + required=False) + tag = serializers.CharField(min_length=0, max_length=64, allow_blank=True, required=False) + + +class 
ComputePlanTesttupleSerializer(serializers.Serializer): + traintuple_id = serializers.CharField(min_length=1, max_length=64) + data_manager_key = serializers.CharField(min_length=64, max_length=64, required=False) + test_data_sample_keys = serializers.ListField( + child=serializers.CharField(min_length=64, max_length=64), + min_length=0, + required=False) + tag = serializers.CharField(min_length=0, max_length=64, allow_blank=True, required=False) + + +class LedgerComputePlanSerializer(serializers.Serializer): + algo_key = serializers.CharField(min_length=64, max_length=64) + objective_key = serializers.CharField(min_length=64, max_length=64) + traintuples = ComputePlanTraintupleSerializer(many=True) + testtuples = ComputePlanTesttupleSerializer(many=True) + + def get_args(self, data): + # convert snake case fields to camel case fields to match chaincode expected inputs + traintuples = [] + for data_traintuple in data['traintuples']: + traintuple = { + 'dataManagerKey': data_traintuple['data_manager_key'], + 'dataSampleKeys': data_traintuple['train_data_sample_keys'], + 'id': data_traintuple['traintuple_id'], + } + if 'in_models_ids' in data_traintuple: + traintuple['inModelsIDs'] = data_traintuple['in_models_ids'] + if 'tag' in data_traintuple: + traintuple['tag'] = data_traintuple['tag'] + + traintuples.append(traintuple) + + testtuples = [] + for data_testtuple in data['testtuples']: + testtuple = { + 'traintupleID': data_testtuple['traintuple_id'], + } + if 'tag' in data_testtuple: + testtuple['tag'] = data_testtuple['tag'] + if 'data_manager_key' in data_testtuple: + testtuple['dataManagerKey'] = data_testtuple['data_manager_key'] + if 'test_data_sample_keys' in data_testtuple: + testtuple['dataSampleKeys'] = data_testtuple['test_data_sample_keys'] + + testtuples.append(testtuple) + + return { + 'algoKey': data['algo_key'], + 'objectiveKey': data['objective_key'], + 'traintuples': traintuples, + 'testtuples': testtuples, + } + + def create(self, validated_data): + args = self.get_args(validated_data) + + if getattr(settings, 'LEDGER_SYNC_ENABLED'): + data = createLedgerComputePlan(args, sync=True) + else: + # use a celery task, as we are in an http request transaction + createLedgerComputePlanAsync.delay(args) + data = { + 'message': 'The substra network has been notified for adding this ComputePlan. ' + 'Please be aware you won\'t get return values from the ledger. 
' + 'You will need to check manually' + } + + return data diff --git a/backend/substrapp/serializers/ledger/computeplan/tasks.py b/backend/substrapp/serializers/ledger/computeplan/tasks.py new file mode 100644 index 000000000..9146b312f --- /dev/null +++ b/backend/substrapp/serializers/ledger/computeplan/tasks.py @@ -0,0 +1,9 @@ +# Create your tasks here +from __future__ import absolute_import, unicode_literals +from celery import shared_task +from .util import createLedgerComputePlan + + +@shared_task +def createLedgerComputePlanAsync(args): + return createLedgerComputePlan(args) diff --git a/backend/substrapp/serializers/ledger/computeplan/util.py b/backend/substrapp/serializers/ledger/computeplan/util.py new file mode 100644 index 000000000..d8d8ba774 --- /dev/null +++ b/backend/substrapp/serializers/ledger/computeplan/util.py @@ -0,0 +1,8 @@ +from __future__ import absolute_import, unicode_literals + + +from substrapp.ledger_utils import invoke_ledger + + +def createLedgerComputePlan(args, sync=False): + return invoke_ledger(fcn='createComputePlan', args=args, sync=sync, only_pkhash=False) diff --git a/substrabac/substrapp/signals/datasample/__init__.py b/backend/substrapp/serializers/ledger/datamanager/__init__.py similarity index 100% rename from substrabac/substrapp/signals/datasample/__init__.py rename to backend/substrapp/serializers/ledger/datamanager/__init__.py diff --git a/substrabac/substrapp/serializers/ledger/datamanager/serializer.py b/backend/substrapp/serializers/ledger/datamanager/serializer.py similarity index 62% rename from substrabac/substrapp/serializers/ledger/datamanager/serializer.py rename to backend/substrapp/serializers/ledger/datamanager/serializer.py index 334365bf5..3d9f87beb 100644 --- a/substrabac/substrapp/serializers/ledger/datamanager/serializer.py +++ b/backend/substrapp/serializers/ledger/datamanager/serializer.py @@ -1,9 +1,10 @@ -from rest_framework import serializers, status +from rest_framework import serializers from django.conf import settings from rest_framework.reverse import reverse from substrapp.utils import get_hash +from substrapp.serializers.ledger.utils import PermissionsSerializer from .util import createLedgerDataManager from .tasks import createLedgerDataManagerAsync @@ -12,40 +13,48 @@ class LedgerDataManagerSerializer(serializers.Serializer): name = serializers.CharField(max_length=100) type = serializers.CharField(max_length=30) objective_key = serializers.CharField(max_length=256, allow_blank=True, required=False) - permissions = serializers.CharField(min_length=1, max_length=60) + permissions = PermissionsSerializer() def create(self, validated_data): instance = self.initial_data.get('instance') name = validated_data.get('name') - type = validated_data.get('type') + data_type = validated_data.get('type') permissions = validated_data.get('permissions') objective_key = validated_data.get('objective_key', '') # TODO, create a datamigration with new Site domain name when we will know the name of the final website - # current_site = Site.objects.get_current() + host = '' + protocol = 'http://' request = self.context.get('request', None) - protocol = 'https://' if request.is_secure() else 'http://' - host = '' if request is None else request.get_host() - args = '"%(name)s", "%(openerHash)s", "%(openerStorageAddress)s", "%(type)s", "%(descriptionHash)s", "%(descriptionStorageAddress)s", "%(objectiveKey)s", "%(permissions)s"' % { + if request: + protocol = 'https://' if request.is_secure() else 'http://' + host = request.get_host() + 
+ args = { 'name': name, 'openerHash': get_hash(instance.data_opener), 'openerStorageAddress': protocol + host + reverse('substrapp:data_manager-opener', args=[instance.pk]), - 'type': type, + 'type': data_type, 'descriptionHash': get_hash(instance.description), - 'descriptionStorageAddress': protocol + host + reverse('substrapp:data_manager-description', args=[instance.pk]), + 'descriptionStorageAddress': protocol + host + reverse('substrapp:data_manager-description', + args=[instance.pk]), 'objectiveKey': objective_key, - 'permissions': permissions + 'permissions': {'process': { + 'public': permissions.get('public'), + 'authorizedIDs': permissions.get('authorized_ids'), + }} } if getattr(settings, 'LEDGER_SYNC_ENABLED'): - return createLedgerDataManager(args, instance.pkhash, sync=True) + data = createLedgerDataManager(args, instance.pkhash, sync=True) else: # use a celery task, as we are in an http request transaction createLedgerDataManagerAsync.delay(args, instance.pkhash) data = { - 'message': 'DataManager added in local db waiting for validation. The substra network has been notified for adding this DataManager' + 'message': 'DataManager added in local db waiting for validation. ' + 'The substra network has been notified for adding this DataManager' } - st = status.HTTP_202_ACCEPTED - return data, st + + return data diff --git a/substrabac/substrapp/serializers/ledger/datamanager/tasks.py b/backend/substrapp/serializers/ledger/datamanager/tasks.py similarity index 99% rename from substrabac/substrapp/serializers/ledger/datamanager/tasks.py rename to backend/substrapp/serializers/ledger/datamanager/tasks.py index 87ff92843..d808c0083 100644 --- a/substrabac/substrapp/serializers/ledger/datamanager/tasks.py +++ b/backend/substrapp/serializers/ledger/datamanager/tasks.py @@ -8,6 +8,7 @@ def createLedgerDataManagerAsync(args, pkhash): return createLedgerDataManager(args, pkhash) + @shared_task def updateLedgerDataManagerAsync(args): return updateLedgerDataManager(args) diff --git a/backend/substrapp/serializers/ledger/datamanager/util.py b/backend/substrapp/serializers/ledger/datamanager/util.py new file mode 100644 index 000000000..e762dc98e --- /dev/null +++ b/backend/substrapp/serializers/ledger/datamanager/util.py @@ -0,0 +1,19 @@ +from __future__ import absolute_import, unicode_literals + +from substrapp.models import DataManager +from substrapp.ledger_utils import invoke_ledger + +from substrapp.serializers.ledger.utils import create_ledger_asset + + +def createLedgerDataManager(args, pkhash, sync=False): + return create_ledger_asset( + model=DataManager, + fcn='registerDataManager', + args=args, + pkhash=pkhash, + sync=sync) + + +def updateLedgerDataManager(args, sync=False): + return invoke_ledger(fcn='updateDataManager', args=args, sync=sync) diff --git a/substrabac/substrapp/signals/model/__init__.py b/backend/substrapp/serializers/ledger/datasample/__init__.py similarity index 100% rename from substrabac/substrapp/signals/model/__init__.py rename to backend/substrapp/serializers/ledger/datasample/__init__.py diff --git a/substrabac/substrapp/serializers/ledger/datasample/serializer.py b/backend/substrapp/serializers/ledger/datasample/serializer.py similarity index 68% rename from substrabac/substrapp/serializers/ledger/datasample/serializer.py rename to backend/substrapp/serializers/ledger/datasample/serializer.py index 8171ead95..891bf984a 100644 --- a/substrabac/substrapp/serializers/ledger/datasample/serializer.py +++ 
b/backend/substrapp/serializers/ledger/datasample/serializer.py @@ -1,6 +1,6 @@ import json -from rest_framework import serializers, status +from rest_framework import serializers from django.conf import settings @@ -18,19 +18,20 @@ def create(self, validated_data): data_manager_keys = validated_data.get('data_manager_keys') test_only = validated_data.get('test_only') - args = '"%(hashes)s", "%(dataManagerKeys)s", "%(testOnly)s"' % { - 'hashes': ','.join([x.pk for x in instances]), - 'dataManagerKeys': ','.join([x for x in data_manager_keys]), + args = { + 'hashes': [x.pk for x in instances], + 'dataManagerKeys': [x for x in data_manager_keys], 'testOnly': json.dumps(test_only), } if getattr(settings, 'LEDGER_SYNC_ENABLED'): - return createLedgerDataSample(args, [x.pk for x in instances], sync=True) + data = createLedgerDataSample(args, [x.pk for x in instances], sync=True) else: # use a celery task, as we are in an http request transaction createLedgerDataSampleAsync.delay(args, [x.pk for x in instances]) data = { - 'message': 'Data samples added in local db waiting for validation. The substra network has been notified for adding this Data' + 'message': 'Data samples added in local db waiting for validation. ' + 'The substra network has been notified for adding this Data' } - st = status.HTTP_202_ACCEPTED - return data, st + + return data diff --git a/substrabac/substrapp/serializers/ledger/datasample/tasks.py b/backend/substrapp/serializers/ledger/datasample/tasks.py similarity index 99% rename from substrabac/substrapp/serializers/ledger/datasample/tasks.py rename to backend/substrapp/serializers/ledger/datasample/tasks.py index 51dfe0f52..f82a2b7ea 100644 --- a/substrabac/substrapp/serializers/ledger/datasample/tasks.py +++ b/backend/substrapp/serializers/ledger/datasample/tasks.py @@ -8,6 +8,7 @@ def createLedgerDataSampleAsync(args, pkhashes): return createLedgerDataSample(args, pkhashes) + @shared_task def updateLedgerDataSampleAsync(args): return updateLedgerDataSample(args) diff --git a/backend/substrapp/serializers/ledger/datasample/util.py b/backend/substrapp/serializers/ledger/datasample/util.py new file mode 100644 index 000000000..ac05b7b62 --- /dev/null +++ b/backend/substrapp/serializers/ledger/datasample/util.py @@ -0,0 +1,18 @@ +from __future__ import absolute_import, unicode_literals + +from substrapp.models import DataSample +from substrapp.ledger_utils import invoke_ledger +from substrapp.serializers.ledger.utils import create_ledger_assets + + +def createLedgerDataSample(args, pkhashes, sync=False): + return create_ledger_assets( + model=DataSample, + fcn='registerDataSample', + args=args, + pkhashes=pkhashes, + sync=sync) + + +def updateLedgerDataSample(args, sync=False): + return invoke_ledger(fcn='updateDataSample', args=args, sync=sync) diff --git a/substrabac/substrapp/signals/objective/__init__.py b/backend/substrapp/serializers/ledger/model/__init__.py similarity index 100% rename from substrabac/substrapp/signals/objective/__init__.py rename to backend/substrapp/serializers/ledger/model/__init__.py diff --git a/substrabac/substrapp/serializers/ledger/model.py b/backend/substrapp/serializers/ledger/model/serializer.py similarity index 100% rename from substrabac/substrapp/serializers/ledger/model.py rename to backend/substrapp/serializers/ledger/model/serializer.py diff --git a/substrabac/substrapp/tests/__init__.py b/backend/substrapp/serializers/ledger/objective/__init__.py similarity index 100% rename from substrabac/substrapp/tests/__init__.py rename to 
backend/substrapp/serializers/ledger/objective/__init__.py diff --git a/substrabac/substrapp/serializers/ledger/objective/serializer.py b/backend/substrapp/serializers/ledger/objective/serializer.py similarity index 68% rename from substrabac/substrapp/serializers/ledger/objective/serializer.py rename to backend/substrapp/serializers/ledger/objective/serializer.py index 372c336c1..f21862456 100644 --- a/substrabac/substrapp/serializers/ledger/objective/serializer.py +++ b/backend/substrapp/serializers/ledger/objective/serializer.py @@ -1,9 +1,10 @@ -from rest_framework import serializers, status +from rest_framework import serializers # from django.contrib.sites.models import Site from django.conf import settings from rest_framework.reverse import reverse from substrapp.utils import get_hash +from substrapp.serializers.ledger.utils import PermissionsSerializer from .util import createLedgerObjective from .tasks import createLedgerObjectiveAsync @@ -14,7 +15,7 @@ class LedgerObjectiveSerializer(serializers.Serializer): required=False) name = serializers.CharField(min_length=1, max_length=100) test_data_manager_key = serializers.CharField(max_length=256, allow_blank=True, required=False) - permissions = serializers.CharField(min_length=1, max_length=60) + permissions = PermissionsSerializer() metrics_name = serializers.CharField(min_length=1, max_length=100) def create(self, validated_data): @@ -26,29 +27,39 @@ def create(self, validated_data): test_data_sample_keys = validated_data.get('test_data_sample_keys', []) # TODO, create a datamigration with new Site domain name when we will know the name of the final website - # current_site = Site.objects.get_current() + host = '' + protocol = 'http://' request = self.context.get('request', None) - protocol = 'https://' if request.is_secure() else 'http://' - host = '' if request is None else request.get_host() - args = '"%(name)s", "%(descriptionHash)s", "%(descriptionStorageAddress)s", "%(metricsName)s", "%(metricsHash)s", "%(metricsStorageAddress)s", "%(testDataSample)s", "%(permissions)s"' % { + if request: + protocol = 'https://' if request.is_secure() else 'http://' + host = request.get_host() + + args = { 'name': name, 'descriptionHash': get_hash(instance.description), - 'descriptionStorageAddress': protocol + host + reverse('substrapp:objective-description', args=[instance.pk]), + 'descriptionStorageAddress': protocol + host + reverse('substrapp:objective-description', args=[instance.pk]), # noqa 'metricsName': metrics_name, 'metricsHash': get_hash(instance.metrics), 'metricsStorageAddress': protocol + host + reverse('substrapp:objective-metrics', args=[instance.pk]), - 'testDataSample': f'{test_data_manager_key}:{",".join([x for x in test_data_sample_keys])}', - 'permissions': permissions + 'testDataset': { + 'dataManagerKey': test_data_manager_key, + 'dataSampleKeys': test_data_sample_keys, + }, + 'permissions': {'process': { + 'public': permissions.get('public'), + 'authorizedIDs': permissions.get('authorized_ids'), + }} } if getattr(settings, 'LEDGER_SYNC_ENABLED'): - return createLedgerObjective(args, instance.pkhash, sync=True) + data = createLedgerObjective(args, instance.pkhash, sync=True) else: # use a celery task, as we are in an http request transaction createLedgerObjectiveAsync.delay(args, instance.pkhash) data = { - 'message': 'Objective added in local db waiting for validation. The substra network has been notified for adding this Objective' + 'message': 'Objective added in local db waiting for validation. 
' + 'The substra network has been notified for adding this Objective' } - st = status.HTTP_202_ACCEPTED - return data, st + + return data diff --git a/substrabac/substrapp/serializers/ledger/objective/tasks.py b/backend/substrapp/serializers/ledger/objective/tasks.py similarity index 100% rename from substrabac/substrapp/serializers/ledger/objective/tasks.py rename to backend/substrapp/serializers/ledger/objective/tasks.py diff --git a/backend/substrapp/serializers/ledger/objective/util.py b/backend/substrapp/serializers/ledger/objective/util.py new file mode 100644 index 000000000..517cba558 --- /dev/null +++ b/backend/substrapp/serializers/ledger/objective/util.py @@ -0,0 +1,13 @@ +from __future__ import absolute_import, unicode_literals + +from substrapp.models import Objective +from substrapp.serializers.ledger.utils import create_ledger_asset + + +def createLedgerObjective(args, pkhash, sync=False): + return create_ledger_asset( + model=Objective, + fcn='registerObjective', + args=args, + pkhash=pkhash, + sync=sync) diff --git a/backend/substrapp/serializers/ledger/testtuple/__init__.py b/backend/substrapp/serializers/ledger/testtuple/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/substrabac/substrapp/serializers/ledger/testtuple/serializer.py b/backend/substrapp/serializers/ledger/testtuple/serializer.py similarity index 76% rename from substrabac/substrapp/serializers/ledger/testtuple/serializer.py rename to backend/substrapp/serializers/ledger/testtuple/serializer.py index c9f5ccbad..30a62a85d 100644 --- a/substrabac/substrapp/serializers/ledger/testtuple/serializer.py +++ b/backend/substrapp/serializers/ledger/testtuple/serializer.py @@ -1,4 +1,4 @@ -from rest_framework import serializers, status +from rest_framework import serializers from django.conf import settings @@ -20,25 +20,27 @@ def get_args(self, validated_data): test_data_sample_keys = validated_data.get('test_data_sample_keys', []) tag = validated_data.get('tag', '') - args = '"%(traintupleKey)s", "%(dataManagerKey)s", "%(dataSampleKeys)s", "%(tag)s"' % { + args = { 'traintupleKey': traintuple_key, 'dataManagerKey': data_manager_key, - 'dataSampleKeys': ','.join(test_data_sample_keys), + 'dataSampleKeys': test_data_sample_keys, 'tag': tag } + return args def create(self, validated_data): args = self.get_args(validated_data) if getattr(settings, 'LEDGER_SYNC_ENABLED'): - return createLedgerTesttuple(args, sync=True) + data = createLedgerTesttuple(args, sync=True) else: # use a celery task, as we are in an http request transaction createLedgerTesttupleAsync.delay(args) - data = { - 'message': 'The substra network has been notified for adding this Testtuple. Please be aware you won\'t get return values from the ledger. You will need to check manually' + 'message': 'The substra network has been notified for adding this Testtuple. ' + 'Please be aware you won\'t get return values from the ledger. 
' + 'You will need to check manually' } - st = status.HTTP_202_ACCEPTED - return data, st + + return data diff --git a/substrabac/substrapp/serializers/ledger/testtuple/tasks.py b/backend/substrapp/serializers/ledger/testtuple/tasks.py similarity index 100% rename from substrabac/substrapp/serializers/ledger/testtuple/tasks.py rename to backend/substrapp/serializers/ledger/testtuple/tasks.py diff --git a/backend/substrapp/serializers/ledger/testtuple/util.py b/backend/substrapp/serializers/ledger/testtuple/util.py new file mode 100644 index 000000000..bed664783 --- /dev/null +++ b/backend/substrapp/serializers/ledger/testtuple/util.py @@ -0,0 +1,9 @@ +from __future__ import absolute_import, unicode_literals + + +from substrapp.ledger_utils import invoke_ledger, retry_on_error + + +@retry_on_error(nbtries=3) +def createLedgerTesttuple(args, sync=False): + return invoke_ledger(fcn='createTesttuple', args=args, sync=sync) diff --git a/backend/substrapp/serializers/ledger/traintuple/__init__.py b/backend/substrapp/serializers/ledger/traintuple/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/substrabac/substrapp/serializers/ledger/traintuple/serializer.py b/backend/substrapp/serializers/ledger/traintuple/serializer.py similarity index 65% rename from substrabac/substrapp/serializers/ledger/traintuple/serializer.py rename to backend/substrapp/serializers/ledger/traintuple/serializer.py index 970a67ba8..32edba29d 100644 --- a/substrabac/substrapp/serializers/ledger/traintuple/serializer.py +++ b/backend/substrapp/serializers/ledger/traintuple/serializer.py @@ -1,4 +1,4 @@ -from rest_framework import serializers, status +from rest_framework import serializers from django.conf import settings @@ -10,8 +10,8 @@ class LedgerTrainTupleSerializer(serializers.Serializer): algo_key = serializers.CharField(min_length=64, max_length=64) data_manager_key = serializers.CharField(min_length=64, max_length=64) objective_key = serializers.CharField(min_length=64, max_length=64) - rank = serializers.IntegerField(allow_null=True, required=False) - FLtask_key = serializers.CharField(min_length=64, max_length=64, allow_blank=True, required=False) + rank = serializers.IntegerField(allow_null=True, required=False, default=0) + compute_plan_id = serializers.CharField(min_length=64, max_length=64, allow_blank=True, required=False) in_models_keys = serializers.ListField(child=serializers.CharField(min_length=64, max_length=64), min_length=0, required=False) @@ -24,19 +24,19 @@ def get_args(self, validated_data): data_manager_key = validated_data.get('data_manager_key') objective_key = validated_data.get('objective_key') rank = validated_data.get('rank', '') - rank = '' if rank is None else rank # rank should be an integer or empty string, not None - FLtask_key = validated_data.get('FLtask_key', '') + rank = '' if rank is None else str(rank) + compute_plan_id = validated_data.get('compute_plan_id', '') train_data_sample_keys = validated_data.get('train_data_sample_keys', []) - in_models_keys = validated_data.get('in_models_keys') + in_models_keys = validated_data.get('in_models_keys', []) tag = validated_data.get('tag', '') - args = '"%(algoKey)s", "%(associatedObjective)s", "%(inModels)s", "%(dataManagerKey)s", "%(dataSampleKeys)s", "%(FLtask)s", "%(rank)s", "%(tag)s"' % { + args = { 'algoKey': algo_key, - 'associatedObjective': objective_key, - 'inModels': ','.join(in_models_keys), + 'objectiveKey': objective_key, + 'inModels': in_models_keys, 'dataManagerKey': data_manager_key, - 
'dataSampleKeys': ','.join(train_data_sample_keys), - 'FLtask': FLtask_key, + 'dataSampleKeys': train_data_sample_keys, + 'computePlanID': compute_plan_id, 'rank': rank, 'tag': tag } @@ -47,13 +47,14 @@ def create(self, validated_data): args = self.get_args(validated_data) if getattr(settings, 'LEDGER_SYNC_ENABLED'): - return createLedgerTraintuple(args, sync=True) + data = createLedgerTraintuple(args, sync=True) else: # use a celery task, as we are in an http request transaction createLedgerTraintupleAsync.delay(args) - data = { - 'message': 'The substra network has been notified for adding this Traintuple. Please be aware you won\'t get return values from the ledger. You will need to check manually' + 'message': 'The substra network has been notified for adding this Traintuple. ' + 'Please be aware you won\'t get return values from the ledger. ' + 'You will need to check manually' } - st = status.HTTP_202_ACCEPTED - return data, st + + return data diff --git a/substrabac/substrapp/serializers/ledger/traintuple/tasks.py b/backend/substrapp/serializers/ledger/traintuple/tasks.py similarity index 100% rename from substrabac/substrapp/serializers/ledger/traintuple/tasks.py rename to backend/substrapp/serializers/ledger/traintuple/tasks.py diff --git a/backend/substrapp/serializers/ledger/traintuple/util.py b/backend/substrapp/serializers/ledger/traintuple/util.py new file mode 100644 index 000000000..124f6d54a --- /dev/null +++ b/backend/substrapp/serializers/ledger/traintuple/util.py @@ -0,0 +1,9 @@ +from __future__ import absolute_import, unicode_literals + + +from substrapp.ledger_utils import invoke_ledger, retry_on_error + + +@retry_on_error(nbtries=3) +def createLedgerTraintuple(args, sync=False): + return invoke_ledger(fcn='createTraintuple', args=args, sync=sync) diff --git a/backend/substrapp/serializers/ledger/utils.py b/backend/substrapp/serializers/ledger/utils.py new file mode 100644 index 000000000..28b2e96a9 --- /dev/null +++ b/backend/substrapp/serializers/ledger/utils.py @@ -0,0 +1,59 @@ +from rest_framework import serializers +from django.core.exceptions import ObjectDoesNotExist +from substrapp.ledger_utils import invoke_ledger, LedgerError, LedgerTimeout + + +class PermissionsSerializer(serializers.Serializer): + public = serializers.BooleanField() + authorized_ids = serializers.ListField(child=serializers.CharField()) + + +def create_ledger_asset(model, fcn, args, pkhash, sync=False): + try: + instance = model.objects.get(pk=pkhash) + except ObjectDoesNotExist: + instance = None + + try: + data = invoke_ledger(fcn=fcn, args=args, sync=sync) + except LedgerTimeout: + # LedgerTimeout herits from LedgerError do not delete + # In case of timeout we keep the instance if it exists + raise + except LedgerError: + # if not created on ledger, delete from local db + if instance: + instance.delete() + raise + + if instance: + instance.validated = True + instance.save() + data['validated'] = True + + return data + + +def create_ledger_assets(model, fcn, args, pkhashes, sync=False): + try: + instances = model.objects.filter(pk__in=pkhashes) + except ObjectDoesNotExist: + instances = None + + try: + data = invoke_ledger(fcn=fcn, args=args, sync=sync) + except LedgerTimeout: + # LedgerTimeout herits from LedgerError do not delete + # In case of timeout we keep the instances if it exists + raise + except LedgerError: + # if not created on ledger, delete from local db + if instances: + instances.delete() + raise + + if instances: + instances.update(validated=True) + data['validated'] = 
True + + return data diff --git a/substrabac/substrapp/serializers/model.py b/backend/substrapp/serializers/model.py similarity index 100% rename from substrabac/substrapp/serializers/model.py rename to backend/substrapp/serializers/model.py diff --git a/substrabac/substrapp/serializers/objective.py b/backend/substrapp/serializers/objective.py similarity index 58% rename from substrabac/substrapp/serializers/objective.py rename to backend/substrapp/serializers/objective.py index ca6c49263..5c5eef6f4 100644 --- a/substrabac/substrapp/serializers/objective.py +++ b/backend/substrapp/serializers/objective.py @@ -1,8 +1,13 @@ +from rest_framework import serializers + from libs.serializers import DynamicFieldsModelSerializer from substrapp.models import Objective +from substrapp.serializers.utils import FileValidator + class ObjectiveSerializer(DynamicFieldsModelSerializer): + metrics = serializers.FileField(validators=[FileValidator()]) class Meta: model = Objective diff --git a/substrabac/substrapp/serializers/algo.py b/backend/substrapp/serializers/utils.py similarity index 67% rename from substrabac/substrapp/serializers/algo.py rename to backend/substrapp/serializers/utils.py index 19b1cc02c..3442cd491 100644 --- a/substrabac/substrapp/serializers/algo.py +++ b/backend/substrapp/serializers/utils.py @@ -2,10 +2,6 @@ import zipfile from django.core.exceptions import ValidationError -from rest_framework import serializers - -from libs.serializers import DynamicFieldsModelSerializer -from substrapp.models import Algo from django.utils.deconstruct import deconstructible @@ -13,10 +9,10 @@ @deconstructible class FileValidator(object): error_messages = { - 'open': ("Cannot handle this file object."), - 'compressed': ("Ensure this file is an archive (zip or tar.* compressed file)."), - 'docker': ("Ensure your archive contains a Dockerfile."), - 'file': ("Ensure your archive contains at least one algo file (for instance algo.py)."), + 'open': ("Cannot handle this file object."), + 'compressed': ("Ensure this file is an archive (zip or tar.* compressed file)."), + 'docker': ("Ensure your archive contains a Dockerfile."), + 'file': ("Ensure your archive contains at least one python file."), } def validate_archive(self, files): @@ -31,7 +27,7 @@ def __call__(self, data): archive = None try: data.file.seek(0) - except: + except Exception: raise ValidationError(self.error_messages['open']) else: try: @@ -51,11 +47,3 @@ def __call__(self, data): archive.close() else: raise ValidationError(self.error_messages['open']) - - -class AlgoSerializer(DynamicFieldsModelSerializer): - file = serializers.FileField(validators=[FileValidator()]) - - class Meta: - model = Algo - fields = '__all__' diff --git a/backend/substrapp/signals/__init__.py b/backend/substrapp/signals/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/substrapp/signals/algo/__init__.py b/backend/substrapp/signals/algo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/substrabac/substrapp/signals/algo/post_delete.py b/backend/substrapp/signals/algo/post_delete.py similarity index 81% rename from substrabac/substrapp/signals/algo/post_delete.py rename to backend/substrapp/signals/algo/post_delete.py index 274af3e72..75411107e 100644 --- a/substrabac/substrapp/signals/algo/post_delete.py +++ b/backend/substrapp/signals/algo/post_delete.py @@ -8,4 +8,5 @@ def algo_post_delete(sender, instance, **kwargs): instance.description.delete(False) directory = path.join(getattr(settings, 'MEDIA_ROOT'), 
'algos/{0}'.format(instance.pk)) - shutil.rmtree(directory) + if path.exists(directory): + shutil.rmtree(directory) diff --git a/backend/substrapp/signals/datamanager/__init__.py b/backend/substrapp/signals/datamanager/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/substrabac/substrapp/signals/datamanager/post_delete.py b/backend/substrapp/signals/datamanager/post_delete.py similarity index 82% rename from substrabac/substrapp/signals/datamanager/post_delete.py rename to backend/substrapp/signals/datamanager/post_delete.py index e6642f3f1..ede34bfa4 100644 --- a/substrabac/substrapp/signals/datamanager/post_delete.py +++ b/backend/substrapp/signals/datamanager/post_delete.py @@ -8,4 +8,5 @@ def datamanager_post_delete(sender, instance, **kwargs): instance.description.delete(False) directory = path.join(getattr(settings, 'MEDIA_ROOT'), 'datamanagers/{0}'.format(instance.pk)) - shutil.rmtree(directory) + if path.exists(directory): + shutil.rmtree(directory) diff --git a/backend/substrapp/signals/datasample/__init__.py b/backend/substrapp/signals/datasample/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/substrabac/substrapp/signals/datasample/post_delete.py b/backend/substrapp/signals/datasample/post_delete.py similarity index 81% rename from substrabac/substrapp/signals/datasample/post_delete.py rename to backend/substrapp/signals/datasample/post_delete.py index 75cff65db..39a9ba6f1 100644 --- a/substrabac/substrapp/signals/datasample/post_delete.py +++ b/backend/substrapp/signals/datasample/post_delete.py @@ -7,4 +7,5 @@ def data_sample_post_delete(sender, instance, **kwargs): # remove created folder directory = path.join(getattr(settings, 'MEDIA_ROOT'), 'datasamples', instance.pk) - rmtree(directory) + if path.exists(directory): + rmtree(directory) diff --git a/backend/substrapp/signals/datasample/pre_save.py b/backend/substrapp/signals/datasample/pre_save.py new file mode 100644 index 000000000..193e8fa57 --- /dev/null +++ b/backend/substrapp/signals/datasample/pre_save.py @@ -0,0 +1,18 @@ +from os import path, link +from os.path import normpath +import shutil +from django.conf import settings + + +def data_sample_pre_save(sender, instance, **kwargs): + directory = path.join(getattr(settings, 'MEDIA_ROOT'), 'datasamples/{0}'.format(instance.pk)) + + # try to make an hard link to keep a free copy of the data + # if not possible, keep the real path location + try: + shutil.copytree(normpath(instance.path), directory, copy_function=link) + except Exception: + shutil.rmtree(directory, ignore_errors=True) + else: + # override path for getting our hardlink + instance.path = directory diff --git a/backend/substrapp/signals/model/__init__.py b/backend/substrapp/signals/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/substrabac/substrapp/signals/model/post_delete.py b/backend/substrapp/signals/model/post_delete.py similarity index 79% rename from substrabac/substrapp/signals/model/post_delete.py rename to backend/substrapp/signals/model/post_delete.py index d26eeb1c1..003c0100a 100644 --- a/substrabac/substrapp/signals/model/post_delete.py +++ b/backend/substrapp/signals/model/post_delete.py @@ -7,4 +7,5 @@ def model_post_delete(sender, instance, **kwargs): instance.file.delete(False) directory = path.join(getattr(settings, 'MEDIA_ROOT'), 'models/{0}'.format(instance.pk)) - shutil.rmtree(directory) + if path.exists(directory): + shutil.rmtree(directory) diff --git a/backend/substrapp/signals/objective/__init__.py 
b/backend/substrapp/signals/objective/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/substrabac/substrapp/signals/objective/post_delete.py b/backend/substrapp/signals/objective/post_delete.py similarity index 82% rename from substrabac/substrapp/signals/objective/post_delete.py rename to backend/substrapp/signals/objective/post_delete.py index 5b1f71ced..fac5e04b4 100644 --- a/substrabac/substrapp/signals/objective/post_delete.py +++ b/backend/substrapp/signals/objective/post_delete.py @@ -8,4 +8,5 @@ def objective_post_delete(sender, instance, **kwargs): instance.metrics.delete(False) directory = path.join(getattr(settings, 'MEDIA_ROOT'), 'objectives/{0}'.format(instance.pk)) - shutil.rmtree(directory) + if path.exists(directory): + shutil.rmtree(directory) diff --git a/backend/substrapp/tasks/__init__.py b/backend/substrapp/tasks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/substrapp/tasks/exception_handler.py b/backend/substrapp/tasks/exception_handler.py new file mode 100644 index 000000000..e97288240 --- /dev/null +++ b/backend/substrapp/tasks/exception_handler.py @@ -0,0 +1,150 @@ +import os +import uuid +import docker.errors +import traceback +import json +import re +import inspect + + +LANGUAGES = { + 'ShellScript': '00', + 'Python': '01' +} + +SERVICES = { + 'System': '00', + 'Docker': '01' +} + +EXCEPTION_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'exceptions.json') + +EXCEPTIONS_UUID_LENGTH = 7 + +if os.path.exists(EXCEPTION_PATH): + try: + EXCEPTIONS_MAP = json.load(open(EXCEPTION_PATH)) + except Exception: + # The json may be corrupted + EXCEPTIONS_MAP = dict() +else: + EXCEPTIONS_MAP = dict() + + +def get_exception_codes_from_docker_trace(): + container_code = EXCEPTIONS_MAP[docker.errors.ContainerError.__name__] + + # Get last line of the docker traceback which contains the traceback inside the container + docker_traceback = traceback.format_exc().splitlines()[-1].encode('utf_8').decode('unicode_escape') + docker_traceback = re.split(':| |\n', docker_traceback) + + exception_codes = [code for exception, code in EXCEPTIONS_MAP.items() + if exception in docker_traceback and code != container_code] + + return exception_codes + + +def get_exception_code(exception_type): + + service_code = SERVICES['System'] + exception_code = EXCEPTIONS_MAP.get(exception_type.__name__, '0000') # '0000' is default exception code + + # Exception inside a docker container + if docker.errors.ContainerError.__name__ in EXCEPTIONS_MAP and \ + exception_code == EXCEPTIONS_MAP[docker.errors.ContainerError.__name__]: + + exception_codes = get_exception_codes_from_docker_trace() + + if len(exception_codes) > 0: + # Take the first code in the list (may have more if multiple exceptions are raised) + service_code = SERVICES['Docker'] + exception_code = exception_codes.pop() + + return exception_code, service_code + + +def compute_error_code(exception): + exception_uuid = str(uuid.uuid4())[:EXCEPTIONS_UUID_LENGTH] + exception_code, service_code = get_exception_code(exception.__class__) + error_code = f'[{service_code}-{LANGUAGES["Python"]}-{exception_code}-{exception_uuid}]' + return error_code + + +def exception_tree(cls, exceptions_classes): + exceptions_classes.add(cls.__name__) + for subcls in cls.__subclasses__(): + exception_tree(subcls, exceptions_classes) + + +def find_exception(module): + # Exception classes in module + exceptions = [ename for ename, eclass in inspect.getmembers(module, inspect.isclass) + if 
issubclass(eclass, BaseException)] + + # Exception classes in submodule + for submodule in inspect.getmembers(module, inspect.ismodule): + exceptions += [ename for ename, eclass in inspect.getmembers(module, inspect.isclass) + if issubclass(eclass, BaseException)] + + return set(exceptions) + + +def generate_exceptions_map(append=True): + + os.environ['DJANGO_SETTINGS_MODULE'] = 'backend.settings.prod' + + import requests.exceptions + import celery.exceptions + import tarfile + import django.core.exceptions + import django.urls + import django.db + import django.http + import django.db.transaction + import rest_framework.exceptions + + # Modules to inspect + MODULES = [docker.errors, requests.exceptions, celery.exceptions, tarfile, + django.core.exceptions, django.urls, django.db, django.http, django.db.transaction, + rest_framework.exceptions] + + exceptions_classes = set() + + # Add exceptions from modules + for errors_module in MODULES: + exceptions_classes.update(find_exception(errors_module)) + + # Add exceptions from python + exception_tree(BaseException, exceptions_classes) + + exceptions_classes = sorted(exceptions_classes) + + if os.path.exists(EXCEPTION_PATH) and append: + # Append values to it + json_exceptions = json.load(open(EXCEPTION_PATH)) + + # get all new exceptions + exceptions_classes = [e for e in exceptions_classes if e not in json_exceptions.keys()] + + # get the last value + start_value = max(map(int, json_exceptions.values())) + + for code_exception, exception_name in enumerate(exceptions_classes, start=start_value + 1): + json_exceptions[exception_name] = f'{code_exception:04d}' + + return json_exceptions + + else: + # Generate the json exceptions + json_exceptions = dict() + for code_exception, exception_name in enumerate(exceptions_classes, start=1): + json_exceptions[exception_name] = f'{code_exception:04d}' + + return json_exceptions + + +if __name__ == '__main__': + os.environ['DJANGO_SETTINGS_MODULE'] = 'backend.settings.common' + json_exceptions = generate_exceptions_map() + with open(EXCEPTION_PATH, 'w') as outfile: + json.dump(json_exceptions, outfile, indent=4) diff --git a/substrabac/substrapp/exceptions.json b/backend/substrapp/tasks/exceptions.json similarity index 99% rename from substrabac/substrapp/exceptions.json rename to backend/substrapp/tasks/exceptions.json index b3e1ad106..aea02b75a 100644 --- a/substrabac/substrapp/exceptions.json +++ b/backend/substrapp/tasks/exceptions.json @@ -480,5 +480,7 @@ "error": "0480", "gaierror": "0481", "herror": "0482", - "timeout": "0483" -} + "timeout": "0483", + "MessageNacked": "0484", + "PkgResourcesDeprecationWarning": "0485" +} \ No newline at end of file diff --git a/backend/substrapp/tasks/tasks.py b/backend/substrapp/tasks/tasks.py new file mode 100644 index 000000000..203a7b00d --- /dev/null +++ b/backend/substrapp/tasks/tasks.py @@ -0,0 +1,493 @@ +from __future__ import absolute_import, unicode_literals + +import os +import shutil +import tempfile +from os import path +import json +from multiprocessing.managers import BaseManager +import logging + +import docker +from checksumdir import dirhash +from django.core.exceptions import ObjectDoesNotExist +from django.conf import settings +from rest_framework.reverse import reverse +from celery.result import AsyncResult +from celery.exceptions import Ignore + +from backend.celery import app +from substrapp.utils import get_hash, get_owner, create_directory, uncompress_content +from substrapp.ledger_utils import (log_start_tuple, log_success_tuple, 
log_fail_tuple, + query_tuples, LedgerError, LedgerStatusError, get_object_from_ledger) +from substrapp.tasks.utils import ResourcesManager, compute_docker, get_asset_content +from substrapp.tasks.exception_handler import compute_error_code + + +def get_objective(subtuple): + from substrapp.models import Objective + + objective_hash = subtuple['objective']['hash'] + + try: + objective = Objective.objects.get(pk=objective_hash) + except ObjectDoesNotExist: + objective = None + + # get objective from ledger as it is not available in local db and store it in local db + if objective is None or not objective.metrics: + objective_metadata = get_object_from_ledger(objective_hash, 'queryObjective') + + content = get_asset_content( + objective_metadata['metrics']['storageAddress'], + objective_metadata['owner'], + objective_metadata['metrics']['hash'], + ) + + objective, _ = Objective.objects.update_or_create(pkhash=objective_hash, validated=True) + + tmp_file = tempfile.TemporaryFile() + tmp_file.write(content) + objective.metrics.save('metrics.archive', tmp_file) + + return objective.metrics.read() + + +def get_algo(subtuple): + algo_hash = subtuple['algo']['hash'] + algo_metadata = get_object_from_ledger(algo_hash, 'queryAlgo') + + algo_content = get_asset_content( + algo_metadata['content']['storageAddress'], + algo_metadata['owner'], + algo_metadata['content']['hash'], + ) + + return algo_content + + +def _get_model(model): + traintuple_hash = model['traintupleKey'] + traintuple_metadata = get_object_from_ledger(traintuple_hash, 'queryTraintuple') + + model_content = get_asset_content( + traintuple_metadata['outModel']['storageAddress'], + traintuple_metadata['dataset']['worker'], + traintuple_metadata['outModel']['hash'], + salt=traintuple_hash, + ) + + return model_content + + +def get_model(subtuple): + model = subtuple.get('model') + if model: + return _get_model(model) + else: + return None + + +def get_models(subtuple): + input_models = subtuple.get('inModels') + if input_models: + return [_get_model(item) for item in input_models] + else: + return [] + + +def _put_model(subtuple, subtuple_directory, model_content, model_hash, traintuple_key): + if not model_content: + raise Exception('Model content should not be empty') + + from substrapp.models import Model + + # store a model in local subtuple directory from input model content + model_dst_path = path.join(subtuple_directory, f'model/{traintuple_key}') + model = None + try: + model = Model.objects.get(pk=model_hash) + except ObjectDoesNotExist: # write it to local disk + with open(model_dst_path, 'wb') as f: + f.write(model_content) + else: + # verify that local db model file is not corrupted + if get_hash(model.file.path, traintuple_key) != model_hash: + raise Exception('Model Hash in Subtuple is not the same as in local db') + + if not os.path.exists(model_dst_path): + os.link(model.file.path, model_dst_path) + else: + # verify that local subtuple model file is not corrupted + if get_hash(model_dst_path, traintuple_key) != model_hash: + raise Exception('Model Hash in Subtuple is not the same as in local medias') + + +def put_model(subtuple, subtuple_directory, model_content): + return _put_model(subtuple, subtuple_directory, model_content, subtuple['model']['hash'], + subtuple['model']['traintupleKey']) + + +def put_models(subtuple, subtuple_directory, models_content): + if not models_content: + raise Exception('Models content should not be empty') + + for model_content, model in zip(models_content, subtuple['inModels']): + 
_put_model(model, subtuple_directory, model_content, model['hash'], model['traintupleKey']) + + +def put_opener(subtuple, subtuple_directory): + from substrapp.models import DataManager + data_opener_hash = subtuple['dataset']['openerHash'] + + datamanager = DataManager.objects.get(pk=data_opener_hash) + + # verify that local db opener file is not corrupted + if get_hash(datamanager.data_opener.path) != data_opener_hash: + raise Exception('DataOpener Hash in Subtuple is not the same as in local db') + + opener_dst_path = path.join(subtuple_directory, 'opener/opener.py') + if not os.path.exists(opener_dst_path): + os.link(datamanager.data_opener.path, opener_dst_path) + else: + # verify that local subtuple data opener file is not corrupted + if get_hash(opener_dst_path) != data_opener_hash: + raise Exception('DataOpener Hash in Subtuple is not the same as in local medias') + + +def put_data_sample(subtuple, subtuple_directory): + from substrapp.models import DataSample + + for data_sample_key in subtuple['dataset']['keys']: + data_sample = DataSample.objects.get(pk=data_sample_key) + data_sample_hash = dirhash(data_sample.path, 'sha256') + if data_sample_hash != data_sample_key: + raise Exception('Data Sample Hash in Subtuple is not the same as in local db') + + # create a symlink on the folder containing data + subtuple_data_directory = path.join(subtuple_directory, 'data', data_sample_key) + try: + os.symlink(data_sample.path, subtuple_data_directory) + except OSError as e: + logging.exception(e) + raise Exception('Failed to create sym link for subtuple data sample') + + +def put_metric(subtuple_directory, metrics_content): + metrics_dst_path = path.join(subtuple_directory, 'metrics/') + uncompress_content(metrics_content, metrics_dst_path) + + +def put_algo(subtuple_directory, algo_content): + uncompress_content(algo_content, subtuple_directory) + + +def build_subtuple_folders(subtuple): + # create a folder named `subtuple['key']` in /medias/subtuple/ with 5 subfolders opener, data, model, pred, metrics + subtuple_directory = path.join(getattr(settings, 'MEDIA_ROOT'), 'subtuple', subtuple['key']) + create_directory(subtuple_directory) + + for folder in ['opener', 'data', 'model', 'pred', 'metrics']: + create_directory(path.join(subtuple_directory, folder)) + + return subtuple_directory + + +def remove_subtuple_materials(subtuple_directory): + try: + shutil.rmtree(subtuple_directory) + except Exception as e: + logging.exception(e) + + +# Instatiate Ressource Manager in BaseManager to share it between celery concurrent tasks +BaseManager.register('ResourcesManager', ResourcesManager) +manager = BaseManager() +manager.start() +resources_manager = manager.ResourcesManager() + + +@app.task(ignore_result=True) +def prepare_training_task(): + prepare_task('traintuple') + + +@app.task(ignore_result=True) +def prepare_testing_task(): + prepare_task('testtuple') + + +def prepare_task(tuple_type): + data_owner = get_owner() + worker_queue = f"{settings.LEDGER['name']}.worker" + tuples = query_tuples(tuple_type, data_owner) + + for subtuple in tuples: + tkey = subtuple['key'] + # Verify that tuple task does not already exist + if AsyncResult(tkey).state == 'PENDING': + prepare_tuple.apply_async( + (subtuple, tuple_type), + task_id=tkey, + queue=worker_queue + ) + else: + print(f'[Scheduler] Tuple task ({tkey}) already exists') + + +@app.task(ignore_result=False) +def prepare_tuple(subtuple, tuple_type): + from django_celery_results.models import TaskResult + + compute_plan_id = None + worker_queue = 
f"{settings.LEDGER['name']}.worker" + + if 'computePlanID' in subtuple and subtuple['computePlanID']: + compute_plan_id = subtuple['computePlanID'] + flresults = TaskResult.objects.filter( + task_name='substrapp.tasks.tasks.compute_task', + result__icontains=f'"computePlanID": "{compute_plan_id}"') + + if flresults and flresults.count() > 0: + worker_queue = json.loads(flresults.first().as_dict()['result'])['worker'] + + try: + log_start_tuple(tuple_type, subtuple['key']) + except LedgerStatusError as e: + # Do not log_fail_tuple in this case, because prepare_tuple task are not unique + # in case of multiple instances of substra backend running for the same organisation + # So prepare_tuple tasks are ignored if it cannot log_start_tuple + logging.exception(e) + raise Ignore() + + try: + compute_task.apply_async( + (tuple_type, subtuple, compute_plan_id), + queue=worker_queue) + except Exception as e: + error_code = compute_error_code(e) + logging.error(error_code, exc_info=True) + log_fail_tuple(tuple_type, subtuple['key'], error_code) + + +@app.task(bind=True, ignore_result=False) +def compute_task(self, tuple_type, subtuple, compute_plan_id): + + try: + worker = self.request.hostname.split('@')[1] + queue = self.request.delivery_info['routing_key'] + except Exception: + worker = f"{settings.LEDGER['name']}.worker" + queue = f"{settings.LEDGER['name']}" + + result = {'worker': worker, 'queue': queue, 'computePlanID': compute_plan_id} + + try: + prepare_materials(subtuple, tuple_type) + res = do_task(subtuple, tuple_type) + except Exception as e: + error_code = compute_error_code(e) + logging.error(error_code, exc_info=True) + + try: + log_fail_tuple(tuple_type, subtuple['key'], error_code) + except LedgerError as e: + logging.exception(e) + + return result + + try: + log_success_tuple(tuple_type, subtuple['key'], res) + except LedgerError as e: + logging.exception(e) + + return result + + +def prepare_materials(subtuple, tuple_type): + + # get subtuple components + metrics_content = get_objective(subtuple) + algo_content = get_algo(subtuple) + if tuple_type == 'testtuple': + model_content = get_model(subtuple) + elif tuple_type == 'traintuple': + models_content = get_models(subtuple) + else: + raise NotImplementedError() + + # create subtuple + subtuple_directory = build_subtuple_folders(subtuple) + put_opener(subtuple, subtuple_directory) + put_data_sample(subtuple, subtuple_directory) + put_metric(subtuple_directory, metrics_content) + put_algo(subtuple_directory, algo_content) + if tuple_type == 'testtuple': + put_model(subtuple, subtuple_directory, model_content) + elif tuple_type == 'traintuple' and models_content: + put_models(subtuple, subtuple_directory, models_content) + + logging.info(f'Prepare materials for {tuple_type} task: success ') + + +def do_task(subtuple, tuple_type): + subtuple_directory = path.join(getattr(settings, 'MEDIA_ROOT'), 'subtuple', subtuple['key']) + org_name = getattr(settings, 'ORG_NAME') + + # compute plan / federated learning variables + compute_plan_id = None + rank = None + + if 'computePlanID' in subtuple and subtuple['computePlanID']: + compute_plan_id = subtuple['computePlanID'] + rank = int(subtuple['rank']) + + client = docker.from_env() + + try: + result = _do_task( + client, + subtuple_directory, + tuple_type, + subtuple, + compute_plan_id, + rank, + org_name + ) + except Exception as e: + if compute_plan_id is not None: + rank = -1 # -1 means last subtuple in the compute plan + raise e + finally: + # Clean subtuple materials + if 
settings.TASK['CLEAN_EXECUTION_ENVIRONMENT']: + remove_subtuple_materials(subtuple_directory) + if rank == -1: + volume_id = f'local-{compute_plan_id}-{org_name}' + local_volume = client.volumes.get(volume_id=volume_id) + try: + local_volume.remove(force=True) + except Exception: + logging.error(f'Cannot remove local volume {volume_id}', exc_info=True) + + return result + + +def _do_task(client, subtuple_directory, tuple_type, subtuple, compute_plan_id, rank, org_name): + + model_path = path.join(subtuple_directory, 'model') + data_path = path.join(subtuple_directory, 'data') + pred_path = path.join(subtuple_directory, 'pred') + opener_file = path.join(subtuple_directory, 'opener/opener.py') + algo_path = path.join(subtuple_directory) + + algo_docker = f'substra/algo_{subtuple["key"][0:8]}'.lower() # tag must be lowercase for docker + algo_docker_name = f'{tuple_type}_{subtuple["key"][0:8]}' + + remove_image = not((compute_plan_id is not None and rank != -1) or settings.TASK['CACHE_DOCKER_IMAGES']) + + # VOLUMES + + symlinks_volume = {} + for subfolder in os.listdir(data_path): + real_path = os.path.realpath(os.path.join(data_path, subfolder)) + symlinks_volume[real_path] = {'bind': f'{real_path}', 'mode': 'ro'} + + volumes = { + data_path: {'bind': '/sandbox/data', 'mode': 'ro'}, + pred_path: {'bind': '/sandbox/pred', 'mode': 'rw'}, + opener_file: {'bind': '/sandbox/opener/__init__.py', 'mode': 'ro'} + } + + model_volume = { + model_path: {'bind': '/sandbox/model', 'mode': 'rw'} + } + + # local volume for training subtuple in compute plan + if compute_plan_id is not None and tuple_type == 'traintuple': + volume_id = f'local-{compute_plan_id}-{org_name}' + if rank == 0: + client.volumes.create(name=volume_id) + else: + client.volumes.get(volume_id=volume_id) + model_volume[volume_id] = {'bind': '/sandbox/local', 'mode': 'rw'} + + # generate command + if tuple_type == 'traintuple': + command = 'train' + algo_docker_name = f'{algo_docker_name}_{command}' + + if subtuple['inModels'] is not None: + inmodels = [subtuple_model["traintupleKey"] for subtuple_model in subtuple['inModels']] + command = f"{command} {' '.join(inmodels)}" + + if rank is not None: + command = f"{command} --rank {rank}" + + elif tuple_type == 'testtuple': + command = 'predict' + algo_docker_name = f'{algo_docker_name}_{command}' + inmodels = subtuple['model']["traintupleKey"] + command = f'{command} {inmodels}' + + compute_docker( + client=client, + resources_manager=resources_manager, + dockerfile_path=algo_path, + image_name=algo_docker, + container_name=algo_docker_name, + volumes={**volumes, **model_volume, **symlinks_volume}, + command=command, + remove_image=remove_image, + remove_container=settings.TASK['CLEAN_EXECUTION_ENVIRONMENT'], + capture_logs=settings.TASK['CAPTURE_LOGS'] + ) + + # save model in database + if tuple_type == 'traintuple': + end_model_file, end_model_file_hash = save_model(subtuple_directory, subtuple['key']) + + # evaluation + metrics_path = f'{subtuple_directory}/metrics' + eval_docker = f'substra/metrics_{subtuple["key"][0:8]}'.lower() # tag must be lowercase for docker + eval_docker_name = f'{tuple_type}_{subtuple["key"][0:8]}_eval' + + compute_docker( + client=client, + resources_manager=resources_manager, + dockerfile_path=metrics_path, + image_name=eval_docker, + container_name=eval_docker_name, + volumes={**volumes, **symlinks_volume}, + command=None, + remove_image=remove_image, + remove_container=settings.TASK['CLEAN_EXECUTION_ENVIRONMENT'], + 
capture_logs=settings.TASK['CAPTURE_LOGS'] + ) + + # load performance + with open(path.join(pred_path, 'perf.json'), 'r') as perf_file: + perf = json.load(perf_file) + global_perf = perf['all'] + + result = {'global_perf': global_perf} + + if tuple_type == 'traintuple': + result['end_model_file_hash'] = end_model_file_hash + result['end_model_file'] = end_model_file + + return result + + +def save_model(subtuple_directory, subtuple_key): + from substrapp.models import Model + end_model_path = path.join(subtuple_directory, 'model/model') + end_model_file_hash = get_hash(end_model_path, subtuple_key) + instance = Model.objects.create(pkhash=end_model_file_hash, validated=True) + + with open(end_model_path, 'rb') as f: + instance.file.save('model', f) + current_site = getattr(settings, "DEFAULT_DOMAIN") + end_model_file = f'{current_site}{reverse("substrapp:model-file", args=[end_model_file_hash])}' + + return end_model_file, end_model_file_hash diff --git a/backend/substrapp/tasks/utils.py b/backend/substrapp/tasks/utils.py new file mode 100644 index 000000000..9eec36a15 --- /dev/null +++ b/backend/substrapp/tasks/utils.py @@ -0,0 +1,251 @@ +import os +import docker +import GPUtil as gputil +import threading + +import logging + +from subprocess import check_output +from django.conf import settings +from requests.auth import HTTPBasicAuth +from substrapp.utils import get_owner, get_remote_file_content, NodeError + + +DOCKER_LABEL = 'substra_task' + +logger = logging.getLogger(__name__) + + +def authenticate_worker(node_id): + from node.models import OutgoingNode + + owner = get_owner() + + try: + outgoing = OutgoingNode.objects.get(node_id=node_id) + except OutgoingNode.DoesNotExist: + raise NodeError(f'Unauthorized to call node_id: {node_id}') + + auth = HTTPBasicAuth(owner, outgoing.secret) + + return auth + + +def get_asset_content(url, node_id, content_hash, salt=None): + return get_remote_file_content(url, authenticate_worker(node_id), content_hash, salt=salt) + + +def get_cpu_sets(cpu_count, concurrency): + cpu_step = max(1, cpu_count // concurrency) + cpu_sets = [] + + for cpu_start in range(0, cpu_count, cpu_step): + cpu_set = f'{cpu_start}-{min(cpu_start + cpu_step - 1, cpu_count - 1)}' + cpu_sets.append(cpu_set) + if len(cpu_sets) == concurrency: + break + + return cpu_sets + + +def get_gpu_sets(gpu_list, concurrency): + + if gpu_list: + gpu_count = len(gpu_list) + gpu_step = max(1, gpu_count // concurrency) + gpu_sets = [] + + for igpu_start in range(0, gpu_count, gpu_step): + gpu_sets.append(','.join(gpu_list[igpu_start: igpu_start + gpu_step])) + else: + gpu_sets = None + + return gpu_sets + + +def expand_cpu_set(cpu_set): + cpu_set_start, cpu_set_stop = map(int, cpu_set.split('-')) + return set(range(cpu_set_start, cpu_set_stop + 1)) + + +def reduce_cpu_set(expanded_cpu_set): + return f'{min(expanded_cpu_set)}-{max(expanded_cpu_set)}' + + +def expand_gpu_set(gpu_set): + return set(gpu_set.split(',')) + + +def reduce_gpu_set(expanded_gpu_set): + return ','.join(sorted(expanded_gpu_set)) + + +def filter_resources_sets(used_resources_sets, resources_sets, expand_resources_set, reduce_resources_set): + """ Filter resources_set used with resources_sets defined. 
+ It will block a resources_set from resources_sets if an used_resources_set in a subset of a resources_set""" + + resources_expand = [expand_resources_set(resources_set) for resources_set in resources_sets] + used_resources_expand = [expand_resources_set(used_resources_set) for used_resources_set in used_resources_sets] + + real_used_resources_sets = [] + + for resources_set in resources_expand: + for used_resources_set in used_resources_expand: + if resources_set.intersection(used_resources_set): + real_used_resources_sets.append(reduce_resources_set(resources_set)) + break + + return list(set(resources_sets).difference(set(real_used_resources_sets))) + + +def filter_cpu_sets(used_cpu_sets, cpu_sets): + return filter_resources_sets(used_cpu_sets, cpu_sets, expand_cpu_set, reduce_cpu_set) + + +def filter_gpu_sets(used_gpu_sets, gpu_sets): + return filter_resources_sets(used_gpu_sets, gpu_sets, expand_gpu_set, reduce_gpu_set) + + +def container_format_log(container_name, container_logs): + logs = [f'[{container_name}] {log}' for log in container_logs.decode().split('\n')] + for log in logs: + logger.info(log) + + +def compute_docker(client, resources_manager, dockerfile_path, image_name, container_name, volumes, command, + remove_image=True, remove_container=True, capture_logs=True): + + dockerfile_fullpath = os.path.join(dockerfile_path, 'Dockerfile') + if not os.path.exists(dockerfile_fullpath): + raise Exception(f'Dockerfile does not exist : {dockerfile_fullpath}') + + try: + client.images.build(path=dockerfile_path, + tag=image_name, + rm=remove_image) + except docker.errors.BuildError as e: + # catch build errors and print them for easier debugging of failed build + lines = [line['stream'].strip() for line in e.build_log if 'stream' in line] + lines = [l for l in lines if l] + error = '\n'.join(lines) + logger.error(f'BuildError: {error}') + raise + + # Limit ressources + memory_limit_mb = f'{resources_manager.memory_limit_mb()}M' + cpu_set, gpu_set = resources_manager.get_cpu_gpu_sets() # blocking call + + task_args = { + 'image': image_name, + 'name': container_name, + 'cpuset_cpus': cpu_set, + 'mem_limit': memory_limit_mb, + 'command': command, + 'volumes': volumes, + 'shm_size': '8G', + 'labels': [DOCKER_LABEL], + 'detach': False, + 'stdout': capture_logs, + 'stderr': capture_logs, + 'auto_remove': False, + 'remove': False, + 'network_disabled': True, + 'network_mode': 'none', + 'privileged': False, + 'cap_drop': ['ALL'] + } + + if gpu_set is not None: + task_args['environment'] = {'NVIDIA_VISIBLE_DEVICES': gpu_set} + task_args['runtime'] = 'nvidia' + + try: + client.containers.run(**task_args) + finally: + # we need to remove the containers to be able to remove the local + # volume in case of compute plan + container = client.containers.get(container_name) + if capture_logs: + container_format_log( + container_name, + container.logs() + ) + container.remove() + + # Remove images + if remove_image: + client.images.remove(image_name, force=True) + + +class ResourcesManager(): + + __concurrency = int(getattr(settings, 'CELERY_WORKER_CONCURRENCY')) + __cpu_count = os.cpu_count() + __cpu_sets = get_cpu_sets(__cpu_count, __concurrency) + + # Set CUDA_DEVICE_ORDER so the IDs assigned by CUDA match those from nvidia-smi + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + __gpu_list = [str(gpu.id) for gpu in gputil.getGPUs()] + __gpu_sets = get_gpu_sets(__gpu_list, __concurrency) # Can be None if no gpu + + __lock = threading.Lock() + __docker = docker.from_env() + + @classmethod + def 
memory_limit_mb(cls): + try: + return int(os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES') / (1024. ** 2)) // cls.__concurrency + except ValueError: + # fixes macOS issue https://github.com/SubstraFoundation/substra-backend/issues/262 + return int(check_output(['sysctl', '-n', 'hw.memsize']).strip()) // cls.__concurrency + + @classmethod + def get_cpu_gpu_sets(cls): + + cpu_set = None + gpu_set = None + + with cls.__lock: + # We can just wait for cpu because cpu and gpu is allocated the same way + while cpu_set is None: + + # Get ressources used + filters = {'status': 'running', + 'label': [DOCKER_LABEL]} + + try: + containers = [container.attrs + for container in cls.__docker.containers.list(filters=filters)] + except docker.errors.APIError as e: + logger.error(e, exc_info=True) + continue + + # CPU + used_cpu_sets = [container['HostConfig']['CpusetCpus'] + for container in containers + if container['HostConfig']['CpusetCpus']] + + cpu_sets_available = filter_cpu_sets(used_cpu_sets, cls.__cpu_sets) + + if cpu_sets_available: + cpu_set = cpu_sets_available.pop() + + # GPU + if cls.__gpu_sets is not None: + env_containers = [container['Config']['Env'] + for container in containers] + + used_gpu_sets = [] + + for env_list in env_containers: + nvidia_env_var = [s.split('=')[1] + for s in env_list if "NVIDIA_VISIBLE_DEVICES" in s] + + used_gpu_sets.extend(nvidia_env_var) + + gpu_sets_available = filter_gpu_sets(used_gpu_sets, cls.__gpu_sets) + + if gpu_sets_available: + gpu_set = gpu_sets_available.pop() + + return cpu_set, gpu_set diff --git a/backend/substrapp/tests/__init__.py b/backend/substrapp/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/substrapp/tests/assets.py b/backend/substrapp/tests/assets.py new file mode 100644 index 000000000..044a96213 --- /dev/null +++ b/backend/substrapp/tests/assets.py @@ -0,0 +1,607 @@ +""" +WARNING +======= + +DO NOT MANUALLY EDIT THIS FILE! + +It is generated using substrapp/tests/generate_assets.py + +In order to update this file: +1. start a clean instance of substra +2. run populate.py +3. 
run substrapp/tests/generate_assets.py +""" + +objective = [ + { + "key": "1cdafbb018dd195690111d74916b76c96892d897ec3587c814f287946db446c3", + "name": "Skin Lesion Classification Objective", + "description": { + "hash": "1cdafbb018dd195690111d74916b76c96892d897ec3587c814f287946db446c3", + "storageAddress": "http://testserver/objective/1cdafbb018dd195690111d74916b76c96892d897ec3587c814f287946db446c3/description/" + }, + "metrics": { + "name": "macro-average recall", + "hash": "506dacd8800c36e70ad3df7379c9164e03452d700bd2c3edb472e6bd0dc01f2e", + "storageAddress": "http://testserver/objective/1cdafbb018dd195690111d74916b76c96892d897ec3587c814f287946db446c3/metrics/" + }, + "owner": "owkinMSP", + "testDataset": None, + "permissions": { + "process": { + "public": True, + "authorizedIDs": [] + } + } + }, + { + "key": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", + "name": "Skin Lesion Classification Objective", + "description": { + "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", + "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/description/" + }, + "metrics": { + "name": "macro-average recall", + "hash": "506dacd8800c36e70ad3df7379c9164e03452d700bd2c3edb472e6bd0dc01f2e", + "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" + }, + "owner": "owkinMSP", + "testDataset": { + "dataManagerKey": "ce9f292c72e9b82697445117f9c2d1d18ce0f8ed07ff91dadb17d668bddf8932", + "dataSampleKeys": [ + "8bf3bf4f753a32f27d18c86405e7a406a83a55610d91abcca9acc525061b8ecf", + "17d58b67ae2028018108c9bf555fa58b2ddcfe560e0117294196e79d26140b2a" + ], + "worker": "" + }, + "permissions": { + "process": { + "public": True, + "authorizedIDs": [] + } + } + } +] + +datamanager = [ + { + "objectiveKey": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", + "description": { + "hash": "15863c2af1fcfee9ca6f61f04be8a0eaaf6a45e4d50c421788d450d198e580f1", + "storageAddress": "http://testserver/data_manager/8dd01465003a9b1e01c99c904d86aa518b3a5dd9dc8d40fe7d075c726ac073ca/description/" + }, + "key": "8dd01465003a9b1e01c99c904d86aa518b3a5dd9dc8d40fe7d075c726ac073ca", + "name": "ISIC 2018", + "opener": { + "hash": "8dd01465003a9b1e01c99c904d86aa518b3a5dd9dc8d40fe7d075c726ac073ca", + "storageAddress": "http://testserver/data_manager/8dd01465003a9b1e01c99c904d86aa518b3a5dd9dc8d40fe7d075c726ac073ca/opener/" + }, + "owner": "chu-nantesMSP", + "permissions": { + "process": { + "public": True, + "authorizedIDs": [] + } + }, + "type": "Images" + }, + { + "objectiveKey": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", + "description": { + "hash": "258bef187a166b3fef5cb86e68c8f7e154c283a148cd5bc344fec7e698821ad3", + "storageAddress": "http://testserver/data_manager/ce9f292c72e9b82697445117f9c2d1d18ce0f8ed07ff91dadb17d668bddf8932/description/" + }, + "key": "ce9f292c72e9b82697445117f9c2d1d18ce0f8ed07ff91dadb17d668bddf8932", + "name": "Simplified ISIC 2018", + "opener": { + "hash": "ce9f292c72e9b82697445117f9c2d1d18ce0f8ed07ff91dadb17d668bddf8932", + "storageAddress": "http://testserver/data_manager/ce9f292c72e9b82697445117f9c2d1d18ce0f8ed07ff91dadb17d668bddf8932/opener/" + }, + "owner": "owkinMSP", + "permissions": { + "process": { + "public": True, + "authorizedIDs": [] + } + }, + "type": "Images" + } +] + +algo = [ + { + "key": "0acc5180e09b6a6ac250f4e3c172e2893f617aa1c22ef1f379019d20fe44142f", + "name": "Neural Network", + 
"content": { + "hash": "0acc5180e09b6a6ac250f4e3c172e2893f617aa1c22ef1f379019d20fe44142f", + "storageAddress": "http://testserver/algo/0acc5180e09b6a6ac250f4e3c172e2893f617aa1c22ef1f379019d20fe44142f/file/" + }, + "description": { + "hash": "b9463411a01ea00869bdffce6e59a5c100a4e635c0a9386266cad3c77eb28e9e", + "storageAddress": "http://testserver/algo/0acc5180e09b6a6ac250f4e3c172e2893f617aa1c22ef1f379019d20fe44142f/description/" + }, + "owner": "chu-nantesMSP", + "permissions": { + "process": { + "public": True, + "authorizedIDs": [] + } + } + }, + { + "key": "9c3d8777e11fd72cbc0fd672bec3a0848f8518b4d56706008cc05f8a1cee44f9", + "name": "Random Forest", + "content": { + "hash": "9c3d8777e11fd72cbc0fd672bec3a0848f8518b4d56706008cc05f8a1cee44f9", + "storageAddress": "http://testserver/algo/9c3d8777e11fd72cbc0fd672bec3a0848f8518b4d56706008cc05f8a1cee44f9/file/" + }, + "description": { + "hash": "4acea40c4b51996c88ef279c5c9aa41ab77b97d38c5ca167e978a98b2e402675", + "storageAddress": "http://testserver/algo/9c3d8777e11fd72cbc0fd672bec3a0848f8518b4d56706008cc05f8a1cee44f9/description/" + }, + "owner": "chu-nantesMSP", + "permissions": { + "process": { + "public": True, + "authorizedIDs": [] + } + } + }, + { + "key": "6523012b72bcd0299f709bc6aaa084d2092dddb9a6256fbffa64645478995a1d", + "name": "Logistic regression", + "content": { + "hash": "6523012b72bcd0299f709bc6aaa084d2092dddb9a6256fbffa64645478995a1d", + "storageAddress": "http://testserver/algo/6523012b72bcd0299f709bc6aaa084d2092dddb9a6256fbffa64645478995a1d/file/" + }, + "description": { + "hash": "124a0425b746d7072282d167b53cb6aab3a31bf1946dae89135c15b0126ebec3", + "storageAddress": "http://testserver/algo/6523012b72bcd0299f709bc6aaa084d2092dddb9a6256fbffa64645478995a1d/description/" + }, + "owner": "owkinMSP", + "permissions": { + "process": { + "public": True, + "authorizedIDs": [] + } + } + } +] + +traintuple = [ + { + "key": "363f70dcc3bf22fdce65e36c957e855b7cd3e2828e6909f34ccc97ee6218541a", + "algo": { + "name": "Neural Network", + "hash": "0acc5180e09b6a6ac250f4e3c172e2893f617aa1c22ef1f379019d20fe44142f", + "storageAddress": "http://testserver/algo/0acc5180e09b6a6ac250f4e3c172e2893f617aa1c22ef1f379019d20fe44142f/file/" + }, + "creator": "chu-nantesMSP", + "dataset": { + "worker": "chu-nantesMSP", + "keys": [ + "dacc0288138cb50569250f996bbe716ec8968fb334d32f29f174c9e79a224127", + "03a1f878768ea8624942d46a3b438c37992e626c2cf655023bcc3bed69d485d1" + ], + "openerHash": "8dd01465003a9b1e01c99c904d86aa518b3a5dd9dc8d40fe7d075c726ac073ca", + "perf": 0 + }, + "computePlanID": "", + "inModels": None, + "log": "[00-01-0032-e18ebeb]", + "objective": { + "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", + "metrics": { + "hash": "506dacd8800c36e70ad3df7379c9164e03452d700bd2c3edb472e6bd0dc01f2e", + "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" + } + }, + "outModel": None, + "permissions": { + "process": { + "public": True, + "authorizedIDs": [] + } + }, + "rank": 0, + "status": "failed", + "tag": "My super tag" + }, + { + "key": "05b44fa4b94d548e35922629f7b23dd84f777d09925bbecb0362081ca528f746", + "algo": { + "name": "Logistic regression", + "hash": "6523012b72bcd0299f709bc6aaa084d2092dddb9a6256fbffa64645478995a1d", + "storageAddress": "http://testserver/algo/6523012b72bcd0299f709bc6aaa084d2092dddb9a6256fbffa64645478995a1d/file/" + }, + "creator": "chu-nantesMSP", + "dataset": { + "worker": "chu-nantesMSP", + "keys": [ + 
"dacc0288138cb50569250f996bbe716ec8968fb334d32f29f174c9e79a224127", + "03a1f878768ea8624942d46a3b438c37992e626c2cf655023bcc3bed69d485d1" + ], + "openerHash": "8dd01465003a9b1e01c99c904d86aa518b3a5dd9dc8d40fe7d075c726ac073ca", + "perf": 1 + }, + "computePlanID": "", + "inModels": None, + "log": "", + "objective": { + "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", + "metrics": { + "hash": "506dacd8800c36e70ad3df7379c9164e03452d700bd2c3edb472e6bd0dc01f2e", + "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" + } + }, + "outModel": { + "hash": "e6a16f5bea8a485f48a8aa8c462155d2d500022a9459c1ff4b3c32acd168ff99", + "storageAddress": "http://testserver/model/e6a16f5bea8a485f48a8aa8c462155d2d500022a9459c1ff4b3c32acd168ff99/file/" + }, + "permissions": { + "process": { + "public": True, + "authorizedIDs": [] + } + }, + "rank": 0, + "status": "done", + "tag": "substra" + }, + { + "key": "32070e156eb4f97d85ff8448ea2ab71f4f275ab845159029354e4446aff974e0", + "algo": { + "name": "Logistic regression", + "hash": "6523012b72bcd0299f709bc6aaa084d2092dddb9a6256fbffa64645478995a1d", + "storageAddress": "http://testserver/algo/6523012b72bcd0299f709bc6aaa084d2092dddb9a6256fbffa64645478995a1d/file/" + }, + "creator": "chu-nantesMSP", + "dataset": { + "worker": "chu-nantesMSP", + "keys": [ + "dacc0288138cb50569250f996bbe716ec8968fb334d32f29f174c9e79a224127", + "e3644123451975be20909fcfd9c664a0573d9bfe04c5021625412d78c3536f1c" + ], + "openerHash": "8dd01465003a9b1e01c99c904d86aa518b3a5dd9dc8d40fe7d075c726ac073ca", + "perf": 1 + }, + "computePlanID": "32070e156eb4f97d85ff8448ea2ab71f4f275ab845159029354e4446aff974e0", + "inModels": None, + "log": "", + "objective": { + "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", + "metrics": { + "hash": "506dacd8800c36e70ad3df7379c9164e03452d700bd2c3edb472e6bd0dc01f2e", + "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" + } + }, + "outModel": { + "hash": "0b1ce6f2bd9247a262c3695aa07aad5ef187197f118c73c60a42e176f8f53b98", + "storageAddress": "http://testserver/model/0b1ce6f2bd9247a262c3695aa07aad5ef187197f118c73c60a42e176f8f53b98/file/" + }, + "permissions": { + "process": { + "public": True, + "authorizedIDs": [] + } + }, + "rank": 0, + "status": "done", + "tag": "" + }, + { + "key": "a2171a1c09738c677748346d22d2b5eea47f874a3b4f4b75224674235892de72", + "algo": { + "name": "Random Forest", + "hash": "9c3d8777e11fd72cbc0fd672bec3a0848f8518b4d56706008cc05f8a1cee44f9", + "storageAddress": "http://testserver/algo/9c3d8777e11fd72cbc0fd672bec3a0848f8518b4d56706008cc05f8a1cee44f9/file/" + }, + "creator": "chu-nantesMSP", + "dataset": { + "worker": "chu-nantesMSP", + "keys": [ + "dacc0288138cb50569250f996bbe716ec8968fb334d32f29f174c9e79a224127", + "03a1f878768ea8624942d46a3b438c37992e626c2cf655023bcc3bed69d485d1" + ], + "openerHash": "8dd01465003a9b1e01c99c904d86aa518b3a5dd9dc8d40fe7d075c726ac073ca", + "perf": 0 + }, + "computePlanID": "", + "inModels": None, + "log": "[00-01-0032-8189cc5]", + "objective": { + "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", + "metrics": { + "hash": "506dacd8800c36e70ad3df7379c9164e03452d700bd2c3edb472e6bd0dc01f2e", + "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" + } + }, + "outModel": None, + "permissions": { + "process": { + "public": True, + 
"authorizedIDs": [] + } + }, + "rank": 0, + "status": "failed", + "tag": "" + } +] + +testtuple = [ + { + "key": "b2d127e65583080bf85d51f4bbc6b04e420414dd668f921c419eb6f078e428ae", + "algo": { + "name": "Logistic regression", + "hash": "6523012b72bcd0299f709bc6aaa084d2092dddb9a6256fbffa64645478995a1d", + "storageAddress": "http://testserver/algo/6523012b72bcd0299f709bc6aaa084d2092dddb9a6256fbffa64645478995a1d/file/" + }, + "certified": True, + "creator": "chu-nantesMSP", + "dataset": { + "worker": "owkinMSP", + "keys": [ + "17d58b67ae2028018108c9bf555fa58b2ddcfe560e0117294196e79d26140b2a", + "8bf3bf4f753a32f27d18c86405e7a406a83a55610d91abcca9acc525061b8ecf" + ], + "openerHash": "ce9f292c72e9b82697445117f9c2d1d18ce0f8ed07ff91dadb17d668bddf8932", + "perf": 0 + }, + "log": "", + "model": { + "traintupleKey": "05b44fa4b94d548e35922629f7b23dd84f777d09925bbecb0362081ca528f746", + "hash": "e6a16f5bea8a485f48a8aa8c462155d2d500022a9459c1ff4b3c32acd168ff99", + "storageAddress": "http://testserver/model/e6a16f5bea8a485f48a8aa8c462155d2d500022a9459c1ff4b3c32acd168ff99/file/" + }, + "objective": { + "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", + "metrics": { + "hash": "506dacd8800c36e70ad3df7379c9164e03452d700bd2c3edb472e6bd0dc01f2e", + "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" + } + }, + "status": "done", + "tag": "substra" + } +] + +model = [ + { + "traintuple": { + "key": "363f70dcc3bf22fdce65e36c957e855b7cd3e2828e6909f34ccc97ee6218541a", + "algo": { + "name": "Neural Network", + "hash": "0acc5180e09b6a6ac250f4e3c172e2893f617aa1c22ef1f379019d20fe44142f", + "storageAddress": "http://testserver/algo/0acc5180e09b6a6ac250f4e3c172e2893f617aa1c22ef1f379019d20fe44142f/file/" + }, + "creator": "chu-nantesMSP", + "dataset": { + "worker": "chu-nantesMSP", + "keys": [ + "dacc0288138cb50569250f996bbe716ec8968fb334d32f29f174c9e79a224127", + "03a1f878768ea8624942d46a3b438c37992e626c2cf655023bcc3bed69d485d1" + ], + "openerHash": "8dd01465003a9b1e01c99c904d86aa518b3a5dd9dc8d40fe7d075c726ac073ca", + "perf": 0 + }, + "computePlanID": "", + "inModels": None, + "log": "[00-01-0032-e18ebeb]", + "objective": { + "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", + "metrics": { + "hash": "506dacd8800c36e70ad3df7379c9164e03452d700bd2c3edb472e6bd0dc01f2e", + "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" + } + }, + "outModel": None, + "permissions": { + "process": { + "public": True, + "authorizedIDs": [] + } + }, + "rank": 0, + "status": "failed", + "tag": "My super tag" + }, + "testtuple": { + "key": "", + "algo": None, + "certified": False, + "creator": "", + "dataset": None, + "log": "", + "model": None, + "objective": None, + "status": "", + "tag": "" + } + }, + { + "traintuple": { + "key": "05b44fa4b94d548e35922629f7b23dd84f777d09925bbecb0362081ca528f746", + "algo": { + "name": "Logistic regression", + "hash": "6523012b72bcd0299f709bc6aaa084d2092dddb9a6256fbffa64645478995a1d", + "storageAddress": "http://testserver/algo/6523012b72bcd0299f709bc6aaa084d2092dddb9a6256fbffa64645478995a1d/file/" + }, + "creator": "chu-nantesMSP", + "dataset": { + "worker": "chu-nantesMSP", + "keys": [ + "dacc0288138cb50569250f996bbe716ec8968fb334d32f29f174c9e79a224127", + "03a1f878768ea8624942d46a3b438c37992e626c2cf655023bcc3bed69d485d1" + ], + "openerHash": "8dd01465003a9b1e01c99c904d86aa518b3a5dd9dc8d40fe7d075c726ac073ca", + 
"perf": 1 + }, + "computePlanID": "", + "inModels": None, + "log": "", + "objective": { + "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", + "metrics": { + "hash": "506dacd8800c36e70ad3df7379c9164e03452d700bd2c3edb472e6bd0dc01f2e", + "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" + } + }, + "outModel": { + "hash": "e6a16f5bea8a485f48a8aa8c462155d2d500022a9459c1ff4b3c32acd168ff99", + "storageAddress": "http://testserver/model/e6a16f5bea8a485f48a8aa8c462155d2d500022a9459c1ff4b3c32acd168ff99/file/" + }, + "permissions": { + "process": { + "public": True, + "authorizedIDs": [] + } + }, + "rank": 0, + "status": "done", + "tag": "substra" + }, + "testtuple": { + "key": "b2d127e65583080bf85d51f4bbc6b04e420414dd668f921c419eb6f078e428ae", + "algo": { + "name": "Logistic regression", + "hash": "6523012b72bcd0299f709bc6aaa084d2092dddb9a6256fbffa64645478995a1d", + "storageAddress": "http://testserver/algo/6523012b72bcd0299f709bc6aaa084d2092dddb9a6256fbffa64645478995a1d/file/" + }, + "certified": True, + "creator": "chu-nantesMSP", + "dataset": { + "worker": "owkinMSP", + "keys": [ + "17d58b67ae2028018108c9bf555fa58b2ddcfe560e0117294196e79d26140b2a", + "8bf3bf4f753a32f27d18c86405e7a406a83a55610d91abcca9acc525061b8ecf" + ], + "openerHash": "ce9f292c72e9b82697445117f9c2d1d18ce0f8ed07ff91dadb17d668bddf8932", + "perf": 0 + }, + "log": "", + "model": { + "traintupleKey": "05b44fa4b94d548e35922629f7b23dd84f777d09925bbecb0362081ca528f746", + "hash": "e6a16f5bea8a485f48a8aa8c462155d2d500022a9459c1ff4b3c32acd168ff99", + "storageAddress": "http://testserver/model/e6a16f5bea8a485f48a8aa8c462155d2d500022a9459c1ff4b3c32acd168ff99/file/" + }, + "objective": { + "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", + "metrics": { + "hash": "506dacd8800c36e70ad3df7379c9164e03452d700bd2c3edb472e6bd0dc01f2e", + "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" + } + }, + "status": "done", + "tag": "substra" + } + }, + { + "traintuple": { + "key": "32070e156eb4f97d85ff8448ea2ab71f4f275ab845159029354e4446aff974e0", + "algo": { + "name": "Logistic regression", + "hash": "6523012b72bcd0299f709bc6aaa084d2092dddb9a6256fbffa64645478995a1d", + "storageAddress": "http://testserver/algo/6523012b72bcd0299f709bc6aaa084d2092dddb9a6256fbffa64645478995a1d/file/" + }, + "creator": "chu-nantesMSP", + "dataset": { + "worker": "chu-nantesMSP", + "keys": [ + "dacc0288138cb50569250f996bbe716ec8968fb334d32f29f174c9e79a224127", + "e3644123451975be20909fcfd9c664a0573d9bfe04c5021625412d78c3536f1c" + ], + "openerHash": "8dd01465003a9b1e01c99c904d86aa518b3a5dd9dc8d40fe7d075c726ac073ca", + "perf": 1 + }, + "computePlanID": "32070e156eb4f97d85ff8448ea2ab71f4f275ab845159029354e4446aff974e0", + "inModels": None, + "log": "", + "objective": { + "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", + "metrics": { + "hash": "506dacd8800c36e70ad3df7379c9164e03452d700bd2c3edb472e6bd0dc01f2e", + "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" + } + }, + "outModel": { + "hash": "0b1ce6f2bd9247a262c3695aa07aad5ef187197f118c73c60a42e176f8f53b98", + "storageAddress": "http://testserver/model/0b1ce6f2bd9247a262c3695aa07aad5ef187197f118c73c60a42e176f8f53b98/file/" + }, + "permissions": { + "process": { + "public": True, + "authorizedIDs": [] + } + }, + "rank": 0, + "status": 
"done", + "tag": "" + }, + "testtuple": { + "key": "", + "algo": None, + "certified": False, + "creator": "", + "dataset": None, + "log": "", + "model": None, + "objective": None, + "status": "", + "tag": "" + } + }, + { + "traintuple": { + "key": "a2171a1c09738c677748346d22d2b5eea47f874a3b4f4b75224674235892de72", + "algo": { + "name": "Random Forest", + "hash": "9c3d8777e11fd72cbc0fd672bec3a0848f8518b4d56706008cc05f8a1cee44f9", + "storageAddress": "http://testserver/algo/9c3d8777e11fd72cbc0fd672bec3a0848f8518b4d56706008cc05f8a1cee44f9/file/" + }, + "creator": "chu-nantesMSP", + "dataset": { + "worker": "chu-nantesMSP", + "keys": [ + "dacc0288138cb50569250f996bbe716ec8968fb334d32f29f174c9e79a224127", + "03a1f878768ea8624942d46a3b438c37992e626c2cf655023bcc3bed69d485d1" + ], + "openerHash": "8dd01465003a9b1e01c99c904d86aa518b3a5dd9dc8d40fe7d075c726ac073ca", + "perf": 0 + }, + "computePlanID": "", + "inModels": None, + "log": "[00-01-0032-8189cc5]", + "objective": { + "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", + "metrics": { + "hash": "506dacd8800c36e70ad3df7379c9164e03452d700bd2c3edb472e6bd0dc01f2e", + "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" + } + }, + "outModel": None, + "permissions": { + "process": { + "public": True, + "authorizedIDs": [] + } + }, + "rank": 0, + "status": "failed", + "tag": "" + }, + "testtuple": { + "key": "", + "algo": None, + "certified": False, + "creator": "", + "dataset": None, + "log": "", + "model": None, + "objective": None, + "status": "", + "tag": "" + } + } +] diff --git a/backend/substrapp/tests/common.py b/backend/substrapp/tests/common.py new file mode 100644 index 000000000..bc72d2abd --- /dev/null +++ b/backend/substrapp/tests/common.py @@ -0,0 +1,296 @@ +from http.cookies import SimpleCookie +from io import StringIO, BytesIO +import os +import base64 + + +from django.contrib.auth.models import User +from django.core.files.uploadedfile import InMemoryUploadedFile +from rest_framework.test import APIClient + + +# This function helper generate a basic authentication header with given credentials +# Given username and password it returns "Basic GENERATED_TOKEN" +from users.serializers import CustomTokenObtainPairSerializer + + +def generate_basic_auth_header(username, password): + return 'Basic ' + base64.b64encode(f'{username}:{password}'.encode()).decode() + + +def generate_jwt_auth_header(jwt): + return 'JWT ' + jwt + + +class AuthenticatedClient(APIClient): + + def request(self, **kwargs): + + # create user + username = 'substra' + password = 'p@$swr0d44' + user, created = User.objects.get_or_create(username=username) + if created: + user.set_password(password) + user.save() + # simulate login + serializer = CustomTokenObtainPairSerializer(data={ + 'username': username, + 'password': password + }) + + serializer.is_valid() + token = serializer.validated_data + jwt = str(token) + + # simulate right httpOnly cookie and Authorization jwt + jwt_auth_header = generate_jwt_auth_header('.'.join(jwt.split('.')[0:2])) + self.credentials(HTTP_AUTHORIZATION=jwt_auth_header) + self.cookies = SimpleCookie({'signature': jwt.split('.')[2]}) + + return super().request(**kwargs) + + +def get_temporary_text_file(contents, filename): + """ + Creates a temporary text file + + :param contents: contents of the file + :param filename: name of the file + :type contents: str + :type filename: str + """ + f = StringIO() + flength = f.write(contents) + text_file = 
InMemoryUploadedFile(f, None, filename, 'text', flength, None) + # Setting the file to its start + text_file.seek(0) + return text_file + + +def get_sample_objective(): + + dir_path = os.path.dirname(os.path.realpath(__file__)) + + description_content = "Super objective" + description_filename = "description.md" + description = get_temporary_text_file(description_content, description_filename) + + metrics_filename = "metrics.zip" + f = BytesIO(b'') + with open(os.path.join(dir_path, + '../../../fixtures/chunantes/objectives/objective0/metrics.zip'), 'rb') as zip_file: + flength = f.write(zip_file.read()) + metrics = InMemoryUploadedFile(f, None, metrics_filename, + 'application/zip', flength, None) + metrics.seek(0) + + return description, description_filename, metrics, metrics_filename + + +def get_sample_script(): + script_content = "import slidelib\n\ndef read():\n\tpass" + script_filename = "script.py" + script = get_temporary_text_file(script_content, script_filename) + + return script, script_filename + + +def get_sample_datamanager(): + description_content = "description" + description_filename = "description.md" + description = get_temporary_text_file(description_content, description_filename) + data_opener_content = "import slidelib\n\ndef read():\n\tpass" + data_opener_filename = "data_opener.py" + data_opener = get_temporary_text_file(data_opener_content, data_opener_filename) + + return description, description_filename, data_opener, data_opener_filename + + +def get_sample_datamanager2(): + description_content = "description 2" + description_filename = "description2.md" + description = get_temporary_text_file(description_content, description_filename) + data_opener_content = "import os\nimport slidelib\n\ndef read():\n\tpass" + data_opener_filename = "data_opener2.py" + data_opener = get_temporary_text_file(data_opener_content, data_opener_filename) + + return description, description_filename, data_opener, data_opener_filename + + +def get_sample_zip_data_sample(): + dir_path = os.path.dirname(os.path.realpath(__file__)) + file_filename = "file.zip" + f = BytesIO(b'foo') + with open(os.path.join(dir_path, '../../../fixtures/owkin/datasamples/datasample4/0024900.zip'), 'rb') as zip_file: + flength = f.write(zip_file.read()) + + file = InMemoryUploadedFile(f, None, file_filename, + 'application/zip', flength, None) + file.seek(0) + + return file, file_filename + + +def get_sample_zip_data_sample_2(): + dir_path = os.path.dirname(os.path.realpath(__file__)) + file_filename = "file.zip" + f = BytesIO(b'foo') + with open(os.path.join(dir_path, '../../../fixtures/owkin/datasamples/test/0024901.zip'), 'rb') as zip_file: + flength = f.write(zip_file.read()) + + file = InMemoryUploadedFile(f, None, file_filename, + 'application/zip', flength, None) + file.seek(0) + + return file, file_filename + + +def get_sample_tar_data_sample(): + dir_path = os.path.dirname(os.path.realpath(__file__)) + file_filename = "file.tar.gz" + f = BytesIO() + with open(os.path.join( + dir_path, '../../../fixtures/owkin/datasamples/datasample4/0024900.tar.gz'), 'rb') as tar_file: + flength = f.write(tar_file.read()) + + file = InMemoryUploadedFile(f, None, file_filename, 'application/zip', flength, None) + file.seek(0) + + return file, file_filename + + +def get_sample_algo(): + dir_path = os.path.dirname(os.path.realpath(__file__)) + file_filename = "file.tar.gz" + f = BytesIO() + with open(os.path.join(dir_path, '../../../fixtures/chunantes/algos/algo3/algo.tar.gz'), 'rb') as tar_file: + flength = 
f.write(tar_file.read()) + + file = InMemoryUploadedFile(f, None, file_filename, 'application/tar+gzip', flength, None) + file.seek(0) + + return file, file_filename + + +def get_sample_algo_zip(): + dir_path = os.path.dirname(os.path.realpath(__file__)) + file_filename = "file.zip" + f = BytesIO() + with open(os.path.join(dir_path, '../../../fixtures/chunantes/algos/algo0/algo.zip'), 'rb') as tar_file: + flength = f.write(tar_file.read()) + + file = InMemoryUploadedFile(f, None, file_filename, 'application/zip', flength, None) + file.seek(0) + + return file, file_filename + + +def get_description_algo(): + dir_path = os.path.dirname(os.path.realpath(__file__)) + file_filename = "file.md" + f = BytesIO() + with open(os.path.join(dir_path, '../../../fixtures/chunantes/algos/algo3/description.md'), 'rb') as desc_file: + flength = f.write(desc_file.read()) + + file = InMemoryUploadedFile(f, None, file_filename, 'application/text', flength, None) + file.seek(0) + + return file, file_filename + + +def get_sample_model(): + model_content = "0.1, 0.2, -1.0" + model_filename = "model.bin" + model = get_temporary_text_file(model_content, model_filename) + + return model, model_filename + + +DEFAULT_PERMISSIONS = { + 'process': { + 'public': True, + 'authorizedIDs': [], + } +} + + +def get_sample_algo_metadata(): + return { + 'owner': 'foo', + 'permissions': DEFAULT_PERMISSIONS, + } + + +def get_sample_objective_metadata(): + return { + 'owner': 'foo', + 'permissions': DEFAULT_PERMISSIONS, + } + + +class FakeMetrics(object): + def __init__(self, filepath='path'): + self.path = filepath + + def save(self, p, f): + return + + def read(self, *args, **kwargs): + return b'foo' + + +class FakeObjective(object): + def __init__(self, filepath='path'): + self.metrics = FakeMetrics(filepath) + + +class FakeOpener(object): + def __init__(self, filepath): + self.path = filepath + self.name = self.path + + +class FakeDataManager(object): + def __init__(self, filepath): + self.data_opener = FakeOpener(filepath) + + +class FakeFilterDataManager(object): + def __init__(self, count): + self.count_value = count + + def count(self): + return self.count_value + + +class FakePath(object): + def __init__(self, filepath): + self.path = filepath + + +class FakeModel(object): + def __init__(self, filepath): + self.file = FakePath(filepath) + + +class FakeAsyncResult(object): + def __init__(self, status=None, successful=True): + if status is not None: + self.status = status + self.success = successful + self.result = {'res': 'result'} + + def successful(self): + return self.success + + +class FakeRequest(object): + def __init__(self, status, content): + self.status_code = status + self.content = content + + +class FakeTask(object): + def __init__(self, task_id): + self.id = task_id diff --git a/backend/substrapp/tests/generate_assets.py b/backend/substrapp/tests/generate_assets.py new file mode 100644 index 000000000..16304e001 --- /dev/null +++ b/backend/substrapp/tests/generate_assets.py @@ -0,0 +1,46 @@ +import os +import json +import substra + + +dir_path = os.path.dirname(__file__) +assets_path = os.path.join(dir_path, 'assets.py') + + +def main(): + + client = substra.Client() + client.add_profile('owkin', 'substra', 'p@$swr0d44', 'http://substra-backend.owkin.xyz:8000', '0.0') + client.login() + + client.set_profile('owkin') + + assets = {} + assets['objective'] = json.dumps(client.list_objective(), indent=4) + assets['datamanager'] = json.dumps(client.list_dataset(), indent=4) + assets['algo'] = 
json.dumps(client.list_algo(), indent=4) + assets['traintuple'] = json.dumps(client.list_traintuple(), indent=4) + assets['testtuple'] = json.dumps(client.list_testtuple(), indent=4) + + assets['model'] = json.dumps([res for res in client.client.list('model') + if ('traintuple' in res and 'testtuple' in res)], indent=4) + + with open(assets_path, 'w') as f: + f.write('"""\nWARNING\n=======\n\nDO NOT MANUALLY EDIT THIS FILE!\n\n' + 'It is generated using substrapp/tests/generate_assets.py\n\n' + 'In order to update this file:\n' + '1. start a clean instance of substra\n' + '2. run populate.py\n' + '3. run substrapp/tests/generate_assets.py\n"""\n\n') + for k, v in assets.items(): + v = v.replace('substra-backend.owkin.xyz:8000', 'testserver') + v = v.replace('substra-backend.chunantes.xyz:8001', 'testserver') + v = v.replace('true', 'True') + v = v.replace('false', 'False') + v = v.replace('null', 'None') + f.write(f'{k} = {v}') + f.write('\n\n') + + +if __name__ == '__main__': + main() diff --git a/backend/substrapp/tests/query/__init__.py b/backend/substrapp/tests/query/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/substrapp/tests/query/tests_query_algo.py b/backend/substrapp/tests/query/tests_query_algo.py new file mode 100644 index 000000000..fa98326be --- /dev/null +++ b/backend/substrapp/tests/query/tests_query_algo.py @@ -0,0 +1,252 @@ +import os +import shutil +import tempfile + +import mock + +from django.urls import reverse +from django.test import override_settings + +from rest_framework import status +from rest_framework.test import APITestCase + +from substrapp.models import Objective, Algo +from substrapp.serializers import LedgerAlgoSerializer +from substrapp.utils import get_hash, compute_hash +from substrapp.ledger_utils import LedgerError + +from ..common import get_sample_objective, get_sample_datamanager, \ + get_sample_algo, get_sample_algo_zip, AuthenticatedClient, \ + get_sample_algo_metadata + + +MEDIA_ROOT = tempfile.mkdtemp() + + +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) +@override_settings(LEDGER_SYNC_ENABLED=True) +class AlgoQueryTests(APITestCase): + client_class = AuthenticatedClient + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.objective_description, self.objective_description_filename, \ + self.objective_metrics, self.objective_metrics_filename = get_sample_objective() + + self.algo, self.algo_filename = get_sample_algo() + self.algo_zip, self.algo_filename_zip = get_sample_algo_zip() + + self.data_description, self.data_description_filename, self.data_data_opener, \ + self.data_opener_filename = get_sample_datamanager() + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + def add_default_objective(self): + Objective.objects.create(description=self.objective_description, + metrics=self.objective_metrics) + + def get_default_algo_data(self): + expected_hash = get_hash(self.algo) + + data = { + 'file': self.algo, + 'description': self.data_description, # fake it + 'name': 'super top algo', + 'objective_key': get_hash(self.objective_description), + 'permissions_public': True, + 'permissions_authorized_ids': [], + } + + return expected_hash, data + + def get_default_algo_data_zip(self): + expected_hash = get_hash(self.algo_zip) + + data = { + 'file': self.algo_zip, + 'description': self.data_description, # fake it + 'name': 'super top algo', + 'objective_key': 
get_hash(self.objective_description), + 'permissions_public': True, + 'permissions_authorized_ids': [], + } + + return expected_hash, data + + def test_add_algo_sync_ok(self): + self.add_default_objective() + pkhash, data = self.get_default_algo_data_zip() + + url = reverse('substrapp:algo-list') + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch('substrapp.serializers.ledger.utils.invoke_ledger') as minvoke_ledger: + minvoke_ledger.return_value = {'pkhash': pkhash} + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + + self.assertEqual(r['pkhash'], pkhash) + self.assertEqual(r['validated'], True) + self.assertEqual(r['description'], + f'http://testserver/media/algos/{r["pkhash"]}/{self.data_description_filename}') + self.assertEqual(r['file'], + f'http://testserver/media/algos/{r["pkhash"]}/{self.algo_filename_zip}') + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + @override_settings(LEDGER_SYNC_ENABLED=False) + @override_settings( + task_eager_propagates=True, + task_always_eager=True, + broker_url='memory://', + backend='memory' + ) + def test_add_algo_no_sync_ok(self): + self.add_default_objective() + pkhash, data = self.get_default_algo_data() + + url = reverse('substrapp:algo-list') + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + with mock.patch('substrapp.serializers.ledger.utils.invoke_ledger') as minvoke_ledger: + minvoke_ledger.return_value = { + 'message': 'Algo added in local db waiting for validation.' + 'The substra network has been notified for adding this Algo' + } + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + + self.assertEqual(r['pkhash'], pkhash) + self.assertEqual(r['validated'], False) + self.assertEqual(r['description'], + f'http://testserver/media/algos/{r["pkhash"]}/{self.data_description_filename}') + self.assertEqual(r['file'], + f'http://testserver/media/algos/{r["pkhash"]}/{self.algo_filename}') + self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) + + def test_add_algo_ko(self): + url = reverse('substrapp:algo-list') + + # non existing associated objective + data = { + 'file': self.algo, + 'description': self.data_description, + 'name': 'super top algo', + 'objective_key': 'non existing objectivexxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', + 'permissions_public': True, + 'permissions_authorized_ids': [], + } + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch.object(LedgerAlgoSerializer, 'create') as mcreate: + mcreate.side_effect = LedgerError('Fail to add algo. 
Objective does not exist') + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + self.assertIn('does not exist', r['message']) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + Objective.objects.create(description=self.objective_description, + metrics=self.objective_metrics) + + # missing local storage field + data = { + 'name': 'super top algo', + 'objective_key': get_hash(self.objective_description), + 'permissions_public': True, + 'permissions_authorized_ids': [], + } + response = self.client.post(url, data, format='multipart', **extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + # missing ledger field + data = { + 'file': self.algo, + 'description': self.data_description, + 'objective_key': get_hash(self.objective_description), + } + response = self.client.post(url, data, format='multipart', **extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_add_algo_no_version(self): + + self.add_default_objective() + + url = reverse('substrapp:algo-list') + + data = { + 'file': self.algo, + 'description': self.data_description, + 'name': 'super top algo', + 'objective_key': get_hash(self.objective_description), + 'permissions_public': True, + 'permissions_authorized_ids': [], + } + response = self.client.post(url, data, format='multipart') + r = response.json() + + self.assertEqual(r, {'detail': 'A version is required.'}) + self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) + + def test_add_algo_wrong_version(self): + + self.add_default_objective() + + url = reverse('substrapp:algo-list') + + data = { + 'file': self.algo, + 'description': self.data_description, + 'name': 'super top algo', + 'objective_key': get_hash(self.objective_description), + 'permissions_public': True, + 'permissions_authorized_ids': [], + } + extra = { + 'HTTP_ACCEPT': 'application/json;version=-1.0', + } + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + + self.assertEqual(r, {'detail': 'Invalid version in "Accept" header.'}) + self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) + + def test_get_algo_files(self): + algo = Algo.objects.create(file=self.algo) + with mock.patch('substrapp.views.utils.get_owner', return_value='foo'), \ + mock.patch('substrapp.views.utils.get_object_from_ledger') \ + as mget_object_from_ledger: + mget_object_from_ledger.return_value = get_sample_algo_metadata() + + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + response = self.client.get(f'/algo/{algo.pkhash}/file/', **extra) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(algo.pkhash, compute_hash(response.getvalue())) + + def test_get_algo_files_no_version(self): + algo = Algo.objects.create(file=self.algo) + response = self.client.get(f'/algo/{algo.pkhash}/file/') + r = response.json() + self.assertEqual(r, {'detail': 'A version is required.'}) + self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) + + def test_get_algo_files_wrong_version(self): + algo = Algo.objects.create(file=self.algo) + extra = { + 'HTTP_ACCEPT': 'application/json;version=-1.0', + } + response = self.client.get(f'/algo/{algo.pkhash}/file/', **extra) + r = response.json() + self.assertEqual(r, {'detail': 'Invalid version in "Accept" header.'}) + self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) diff --git a/backend/substrapp/tests/query/tests_query_datamanager.py 
b/backend/substrapp/tests/query/tests_query_datamanager.py new file mode 100644 index 000000000..1282eab31 --- /dev/null +++ b/backend/substrapp/tests/query/tests_query_datamanager.py @@ -0,0 +1,136 @@ +import os +import shutil +import tempfile + +import mock + +from django.urls import reverse +from django.test import override_settings + +from rest_framework import status +from rest_framework.test import APITestCase + + +from substrapp.utils import get_hash + +from ..common import get_sample_datamanager, AuthenticatedClient + +MEDIA_ROOT = tempfile.mkdtemp() + + +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) +@override_settings(LEDGER_SYNC_ENABLED=True) +class DataManagerQueryTests(APITestCase): + client_class = AuthenticatedClient + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.data_description, self.data_description_filename, self.data_data_opener, \ + self.data_opener_filename = get_sample_datamanager() + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + def get_default_datamanager_data(self): + expected_hash = get_hash(self.data_data_opener) + data = { + 'name': 'slide opener', + 'type': 'images', + 'permissions_public': True, + 'permissions_authorized_ids': [], + 'objective_key': '', + 'description': self.data_description, + 'data_opener': self.data_data_opener + } + return expected_hash, data + + def test_add_datamanager_sync_ok(self): + + pkhash, data = self.get_default_datamanager_data() + + url = reverse('substrapp:data_manager-list') + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch('substrapp.serializers.ledger.utils.invoke_ledger') as minvoke_ledger: + minvoke_ledger.return_value = {'pkhash': pkhash} + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + + self.assertEqual(r['pkhash'], pkhash) + self.assertEqual(r['validated'], True) + self.assertEqual(r['description'], + f'http://testserver/media/datamanagers/{r["pkhash"]}/{self.data_description_filename}') + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + @override_settings(LEDGER_SYNC_ENABLED=False) + @override_settings( + task_eager_propagates=True, + task_always_eager=True, + broker_url='memory://', + backend='memory' + ) + def test_add_datamanager_no_sync_ok(self): + + pkhash, data = self.get_default_datamanager_data() + + url = reverse('substrapp:data_manager-list') + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch('substrapp.serializers.ledger.utils.invoke_ledger') as minvoke_ledger: + minvoke_ledger.return_value = { + 'message': 'DataManager added in local db waiting for validation.' 
+ 'The substra network has been notified for adding this DataManager' + } + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + + self.assertEqual(r['pkhash'], pkhash) + self.assertEqual(r['validated'], False) + self.assertEqual(r['description'], + f'http://testserver/media/datamanagers/{r["pkhash"]}/{self.data_description_filename}') + self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) + + def test_add_datamanager_ko(self): + data = {'name': 'toto'} + + url = reverse('substrapp:data_manager-list') + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + response = self.client.post(url, data, format='multipart', **extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_add_datamanager_no_version(self): + + _, data = self.get_default_datamanager_data() + + url = reverse('substrapp:data_manager-list') + + response = self.client.post(url, data, format='multipart') + r = response.json() + + self.assertEqual(r, {'detail': 'A version is required.'}) + self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) + + def test_add_datamanager_wrong_version(self): + + _, data = self.get_default_datamanager_data() + + url = reverse('substrapp:data_manager-list') + extra = { + 'HTTP_ACCEPT': 'application/json;version=-1.0', + } + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + + self.assertEqual(r, {'detail': 'Invalid version in "Accept" header.'}) + self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) diff --git a/backend/substrapp/tests/query/tests_query_datasample.py b/backend/substrapp/tests/query/tests_query_datasample.py new file mode 100644 index 000000000..83e8d8fa9 --- /dev/null +++ b/backend/substrapp/tests/query/tests_query_datasample.py @@ -0,0 +1,510 @@ +import os +import shutil +import tempfile +import zipfile +from unittest.mock import MagicMock + +import mock +from django.core.files import File +from django.core.files.uploadedfile import InMemoryUploadedFile + +from django.urls import reverse +from django.test import override_settings + +from rest_framework import status +from rest_framework.test import APITestCase + +from substrapp.models import DataManager, DataSample +from substrapp.serializers import LedgerDataSampleSerializer, DataSampleSerializer + +from substrapp.utils import get_hash, get_dir_hash, store_datasamples_archive +from substrapp.ledger_utils import LedgerError, LedgerTimeout +from substrapp.views import DataSampleViewSet + +from ..common import get_sample_datamanager, get_sample_zip_data_sample, get_sample_script, \ + get_sample_datamanager2, get_sample_tar_data_sample, get_sample_zip_data_sample_2, AuthenticatedClient + +MEDIA_ROOT = tempfile.mkdtemp() + + +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) +@override_settings(LEDGER_SYNC_ENABLED=True) +class DataSampleQueryTests(APITestCase): + client_class = AuthenticatedClient + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.script, self.script_filename = get_sample_script() + self.data_file, self.data_file_filename = get_sample_zip_data_sample() + self.data_file_2, self.data_file_filename_2 = get_sample_zip_data_sample_2() + self.data_tar_file, self.data_tar_file_filename = get_sample_tar_data_sample() + + self.data_description, self.data_description_filename, self.data_data_opener, \ + self.data_opener_filename = get_sample_datamanager() + + 
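# a second opener/description pair (from get_sample_datamanager2) is set up so the bulk data sample tests below can exercise two data managers at once +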
self.data_description2, self.data_description_filename2, self.data_data_opener2, \ + self.data_opener_filename2 = get_sample_datamanager2() + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + def add_default_data_manager(self): + DataManager.objects.create(name='slide opener', + description=self.data_description, + data_opener=self.data_data_opener) + + DataManager.objects.create(name='slide opener', + description=self.data_description2, + data_opener=self.data_data_opener2) + + def get_default_datasample_data(self): + expected_hash = get_dir_hash(self.data_file.file) + self.data_file.file.seek(0) + data = { + 'file': self.data_file, + 'data_manager_keys': [get_hash(self.data_data_opener)], + 'test_only': True, + } + + return expected_hash, data + + def test_add_data_sample_sync_ok(self): + + self.add_default_data_manager() + pkhash, data = self.get_default_datasample_data() + + url = reverse('substrapp:data_sample-list') + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch('substrapp.serializers.ledger.datasample.util.create_ledger_assets') as mcreate_ledger_assets: + mcreate_ledger_assets.return_value = { + 'pkhash': pkhash, + 'validated': True + } + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + + self.assertEqual(r[0]['pkhash'], pkhash) + self.assertEqual(r[0]['validated'], True) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + def test_bulk_add_data_sample_sync_ok(self): + + self.add_default_data_manager() + + url = reverse('substrapp:data_sample-list') + + file_mock = MagicMock(spec=InMemoryUploadedFile) + file_mock2 = MagicMock(spec=InMemoryUploadedFile) + file_mock.name = 'foo.zip' + file_mock2.name = 'bar.zip' + file_mock.read = MagicMock(return_value=self.data_file.read()) + file_mock2.read = MagicMock(return_value=self.data_file_2.read()) + + data = { + file_mock.name: file_mock, + file_mock2.name: file_mock2, + 'data_manager_keys': [get_hash(self.data_data_opener), get_hash(self.data_data_opener2)], + 'test_only': True, + } + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch('substrapp.serializers.ledger.datasample.util.create_ledger_assets') as mcreate_ledger_assets: + self.data_file.seek(0) + self.data_file_2.seek(0) + ledger_data = {'pkhash': [get_dir_hash(file_mock), get_dir_hash(file_mock2)], 'validated': True} + mcreate_ledger_assets.return_value = ledger_data + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + + self.assertEqual(len(r), 2) + self.assertEqual(r[0]['pkhash'], get_dir_hash(file_mock)) + self.assertTrue(r[0]['path'].endswith(f'/datasamples/{get_dir_hash(file_mock)}')) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + @override_settings(LEDGER_SYNC_ENABLED=False) + @override_settings( + task_eager_propagates=True, + task_always_eager=True, + broker_url='memory://', + backend='memory' + ) + def test_add_data_sample_no_sync_ok(self): + self.add_default_data_manager() + pkhash, data = self.get_default_datasample_data() + + url = reverse('substrapp:data_sample-list') + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch('substrapp.serializers.ledger.datasample.util.create_ledger_assets') as mcreate_ledger_assets: + mcreate_ledger_assets.return_value = '' + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + self.assertEqual(r[0]['pkhash'], pkhash) + 
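# with LEDGER_SYNC_ENABLED=False the ledger registration is deferred to a Celery task (run eagerly in this test), so the view is expected to answer 202 Accepted rather than 201 Created +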
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) + + def test_add_data_sample_ko(self): + url = reverse('substrapp:data_sample-list') + + # missing datamanager + data = {'data_manager_keys': ['toto']} + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + self.assertEqual( + r['message'], + "One or more datamanager keys provided do not exist in local database. " + "Please create them before. DataManager keys: ['toto']") + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.add_default_data_manager() + + # missing local storage field + data = {'data_manager_keys': [get_hash(self.data_description)], + 'test_only': True, } + response = self.client.post(url, data, format='multipart', **extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + # missing ledger field + data = {'data_manager_keys': [get_hash(self.data_description)], + 'file': self.script, } + response = self.client.post(url, data, format='multipart', **extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_add_data_sample_ko_already_exists(self): + url = reverse('substrapp:data_sample-list') + + self.add_default_data_manager() + + file_mock = MagicMock(spec=InMemoryUploadedFile) + file_mock.name = 'foo.zip' + file_mock.read = MagicMock(return_value=self.data_file.file.read()) + file_mock.open = MagicMock(return_value=file_mock) + + _, datasamples_path_from_file = store_datasamples_archive(file_mock) + + d = DataSample(path=datasamples_path_from_file) + # trigger pre save + d.save() + + data = { + 'file': file_mock, + 'data_manager_keys': [get_hash(self.data_data_opener)], + 'test_only': True, + } + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch.object(zipfile, 'is_zipfile') as mis_zipfile: + mis_zipfile.return_value = True + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + self.assertEqual(r['message'], + [[{'pkhash': ['data sample with this pkhash already exists.']}]]) + self.assertEqual(response.status_code, status.HTTP_409_CONFLICT) + + def test_add_data_sample_ko_not_a_zip(self): + url = reverse('substrapp:data_sample-list') + + self.add_default_data_manager() + + file_mock = MagicMock(spec=File) + file_mock.name = 'foo.zip' + file_mock.read = MagicMock(return_value=b'foo') + + data = { + 'file': file_mock, + 'data_manager_keys': [get_hash(self.data_data_opener)], + 'test_only': True, + } + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + self.assertEqual(r['message'], 'Archive must be zip or tar.*') + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_add_data_sample_ko_408(self): + url = reverse('substrapp:data_sample-list') + + self.add_default_data_manager() + + file_mock = MagicMock(spec=InMemoryUploadedFile) + file_mock.name = 'foo.zip' + file_mock.read = MagicMock(return_value=self.data_file.file.read()) + file_mock.open = MagicMock(return_value=file_mock) + + data = { + 'file': file_mock, + 'data_manager_keys': [get_hash(self.data_data_opener)], + 'test_only': True, + } + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch.object(zipfile, 'is_zipfile') as mis_zipfile, \ + mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: + mcreate.side_effect = 
LedgerTimeout('Timeout') + mis_zipfile.return_value = True + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + self.assertEqual( + r['message'], + {'pkhash': [get_dir_hash(file_mock)], 'validated': False}) + self.assertEqual(response.status_code, status.HTTP_408_REQUEST_TIMEOUT) + + def test_bulk_add_data_sample_ko_408(self): + + self.add_default_data_manager() + + url = reverse('substrapp:data_sample-list') + + file_mock = MagicMock(spec=InMemoryUploadedFile) + file_mock2 = MagicMock(spec=InMemoryUploadedFile) + file_mock.name = 'foo.zip' + file_mock2.name = 'bar.zip' + file_mock.read = MagicMock(return_value=self.data_file.read()) + file_mock2.read = MagicMock(return_value=self.data_file_2.read()) + + data = { + file_mock.name: file_mock, + file_mock2.name: file_mock2, + 'data_manager_keys': [get_hash(self.data_data_opener)], + 'test_only': True, + } + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch('substrapp.serializers.datasample.DataSampleSerializer.get_validators') as mget_validators, \ + mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: + mget_validators.return_value = [] + self.data_file.seek(0) + self.data_tar_file.seek(0) + mcreate.side_effect = LedgerTimeout('Timeout') + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + self.assertEqual(r['message']['validated'], False) + self.assertEqual(DataSample.objects.count(), 2) + self.assertEqual(response.status_code, status.HTTP_408_REQUEST_TIMEOUT) + + def test_bulk_add_data_sample_ko_same_pkhash(self): + + self.add_default_data_manager() + + url = reverse('substrapp:data_sample-list') + + file_mock = MagicMock(spec=InMemoryUploadedFile) + file_mock2 = MagicMock(spec=InMemoryUploadedFile) + file_mock.name = 'foo.zip' + file_mock2.name = 'bar.tar.gz' + file_mock.read = MagicMock(return_value=self.data_file.read()) + file_mock2.read = MagicMock(return_value=self.data_tar_file.read()) + + data = { + file_mock.name: file_mock, + file_mock2.name: file_mock2, + 'data_manager_keys': [get_hash(self.data_data_opener)], + 'test_only': True, + } + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch('substrapp.serializers.datasample.DataSampleSerializer.get_validators') as mget_validators, \ + mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: + mget_validators.return_value = [] + self.data_file.seek(0) + self.data_tar_file.seek(0) + ledger_data = {'pkhash': [get_dir_hash(file_mock), get_dir_hash(file_mock2)], 'validated': False} + mcreate.return_value = ledger_data, status.HTTP_408_REQUEST_TIMEOUT + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + self.assertEqual(DataSample.objects.count(), 0) + self.assertEqual( + r['message'], + f'Your data sample archives contain same files leading to same pkhash, ' + f'please review the content of your achives. 
' + f'Archives {file_mock2.name} and {file_mock.name} are the same') + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_add_data_sample_ko_400(self): + url = reverse('substrapp:data_sample-list') + + self.add_default_data_manager() + + file_mock = MagicMock(spec=InMemoryUploadedFile) + file_mock.name = 'foo.zip' + file_mock.read = MagicMock(return_value=self.data_file.file.read()) + + data = { + 'file': file_mock, + 'data_manager_keys': [get_hash(self.data_data_opener)], + 'test_only': True, + } + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch.object(zipfile, 'is_zipfile') as mis_zipfile, \ + mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: + mcreate.side_effect = LedgerError('Failed') + mis_zipfile.return_value = True + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + self.assertEqual(r['message'], 'Failed') + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_add_data_sample_ko_serializer_invalid(self): + url = reverse('substrapp:data_sample-list') + + self.add_default_data_manager() + + file_mock = MagicMock(spec=InMemoryUploadedFile) + file_mock.name = 'foo.zip' + file_mock.read = MagicMock(return_value=self.data_file.read()) + + data = { + 'file': file_mock, + 'data_manager_keys': [get_hash(self.data_data_opener)], + 'test_only': True, + } + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch.object(zipfile, 'is_zipfile') as mis_zipfile, \ + mock.patch.object(DataSampleViewSet, 'get_serializer') as mget_serializer: + mocked_serializer = MagicMock(DataSampleSerializer) + mocked_serializer.is_valid.return_value = True + mocked_serializer.save.side_effect = Exception('Failed') + mget_serializer.return_value = mocked_serializer + + mis_zipfile.return_value = True + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + self.assertEqual(r['message'], "Failed") + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_add_data_sample_ko_ledger_invalid(self): + url = reverse('substrapp:data_sample-list') + + self.add_default_data_manager() + + file_mock = MagicMock(spec=InMemoryUploadedFile) + file_mock.name = 'foo.zip' + file_mock.read = MagicMock(return_value=self.data_file.file.read()) + + data = { + 'file': file_mock, + 'data_manager_keys': [get_hash(self.data_data_opener)], + 'test_only': True, + } + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch.object(zipfile, 'is_zipfile') as mis_zipfile, \ + mock.patch('substrapp.views.datasample.LedgerDataSampleSerializer', + spec=True) as mLedgerDataSampleSerializer: + mocked_LedgerDataSampleSerializer = MagicMock() + mocked_LedgerDataSampleSerializer.is_valid.return_value = False + mocked_LedgerDataSampleSerializer.errors = 'Failed' + mLedgerDataSampleSerializer.return_value = mocked_LedgerDataSampleSerializer + + mis_zipfile.return_value = True + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + self.assertEqual(r['message'], "[ErrorDetail(string='Failed', code='invalid')]") + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_add_data_sample_no_version(self): + + self.add_default_data_manager() + + url = reverse('substrapp:data_sample-list') + + data = { + 'file': self.data_file, + 'data_manager_keys': [get_hash(self.data_description)], + 'test_only': True, + } + response = 
self.client.post(url, data, format='multipart') + r = response.json() + + self.assertEqual(r, {'detail': 'A version is required.'}) + self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) + + def test_add_data_sample_wrong_version(self): + + self.add_default_data_manager() + + url = reverse('substrapp:data_sample-list') + + data = { + 'file': self.script, + 'data_manager_keys': ['XXXX'], + } + extra = { + 'HTTP_ACCEPT': 'application/json;version=-1.0', + } + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + + self.assertEqual(r, {'detail': 'Invalid version in "Accept" header.'}) + self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) + + def test_bulk_update_data(self): + + # add associated data opener + datamanager = DataManager.objects.create(name='slide opener', + description=self.data_description, + data_opener=self.data_data_opener) + datamanager2 = DataManager.objects.create(name='slide opener 2', + description=self.data_description2, + data_opener=self.data_data_opener2) + + d = DataSample(path=self.data_file) + # trigger pre save + d.save() + + url = reverse('substrapp:data_sample-bulk-update') + + data = { + 'data_manager_keys': [datamanager.pkhash, datamanager2.pkhash], + 'data_sample_keys': [d.pkhash], + } + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch( + 'substrapp.serializers.ledger.datasample.util.invoke_ledger') as minvoke_ledger: + minvoke_ledger.return_value = {'keys': [ + d.pkhash]} + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + self.assertEqual(r['keys'], [d.pkhash]) + self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/backend/substrapp/tests/query/tests_query_objective.py b/backend/substrapp/tests/query/tests_query_objective.py new file mode 100644 index 000000000..c61bdc80c --- /dev/null +++ b/backend/substrapp/tests/query/tests_query_objective.py @@ -0,0 +1,258 @@ +import os +import shutil +import tempfile + +import mock + + +from django.urls import reverse +from django.test import override_settings + +from rest_framework import status +from rest_framework.test import APITestCase + +from substrapp.models import Objective, DataManager +from substrapp.utils import get_hash, compute_hash + +from ..common import get_sample_objective, get_sample_datamanager, \ + get_temporary_text_file, AuthenticatedClient, get_sample_objective_metadata + +MEDIA_ROOT = tempfile.mkdtemp() + + +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) +@override_settings(LEDGER_SYNC_ENABLED=True) +class ObjectiveQueryTests(APITestCase): + client_class = AuthenticatedClient + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.objective_description, self.objective_description_filename, \ + self.objective_metrics, self.objective_metrics_filename = get_sample_objective() + + self.data_description, self.data_description_filename, self.data_data_opener, \ + self.data_opener_filename = get_sample_datamanager() + + self.test_data_sample_keys = [ + '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b379', + '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b389'] + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + def add_default_data_manager(self): + DataManager.objects.create(name='slide opener', + description=self.data_description, + data_opener=self.data_data_opener) + + def 
get_default_objective_data(self): + + self.objective_description, self.objective_description_filename, \ + self.objective_metrics, self.objective_metrics_filename = get_sample_objective() + + expected_hash = get_hash(self.objective_description) + data = { + 'name': 'tough objective', + 'test_data_manager_key': get_hash(self.data_data_opener), + 'test_data_sample_keys': self.test_data_sample_keys, + 'description': self.objective_description, + 'metrics': self.objective_metrics, + 'permissions_public': True, + 'permissions_authorized_ids': [], + 'metrics_name': 'accuracy' + } + return expected_hash, data + + def test_add_objective_sync_ok(self): + self.add_default_data_manager() + pkhash, data = self.get_default_objective_data() + + url = reverse('substrapp:objective-list') + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch('substrapp.serializers.ledger.utils.invoke_ledger') as minvoke_ledger: + minvoke_ledger.return_value = {'pkhash': pkhash} + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + + self.assertEqual(r['pkhash'], pkhash) + self.assertEqual(r['validated'], True) + self.assertEqual(r['description'], + f'http://testserver/media/objectives/{r["pkhash"]}/{self.objective_description_filename}') + self.assertEqual(r['metrics'], + f'http://testserver/media/objectives/{r["pkhash"]}/{self.objective_metrics_filename}') + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + @override_settings(LEDGER_SYNC_ENABLED=False) + @override_settings( + task_eager_propagates=True, + task_always_eager=True, + broker_url='memory://', + backend='memory' + ) + def test_add_objective_no_sync_ok(self): + self.add_default_data_manager() + pkhash, data = self.get_default_objective_data() + + url = reverse('substrapp:objective-list') + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + with mock.patch('substrapp.serializers.ledger.utils.invoke_ledger') as minvoke_ledger: + minvoke_ledger.return_value = { + 'message': 'Objective added in local db waiting for validation.' 
+ 'The substra network has been notified for adding this Objective' + } + response = self.client.post(url, data, format='multipart', **extra) + + r = response.json() + + self.assertEqual(r['pkhash'], pkhash) + self.assertEqual(r['validated'], False) + self.assertEqual(r['description'], + f'http://testserver/media/objectives/{r["pkhash"]}/{self.objective_description_filename}') + self.assertEqual(r['metrics'], + f'http://testserver/media/objectives/{r["pkhash"]}/{self.objective_metrics_filename}') + self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) + + def test_add_objective_conflict(self): + self.add_default_data_manager() + + pkhash, data = self.get_default_objective_data() + + url = reverse('substrapp:objective-list') + + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch('substrapp.serializers.ledger.utils.invoke_ledger') as minvoke_ledger: + minvoke_ledger.return_value = {'pkhash': pkhash} + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + + self.assertEqual(r['pkhash'], pkhash) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + # XXX reload data as the previous call to post change it + _, data = self.get_default_objective_data() + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + + self.assertEqual(response.status_code, status.HTTP_409_CONFLICT) + self.assertEqual(r['pkhash'], pkhash) + + def test_add_objective_ko(self): + url = reverse('substrapp:objective-list') + + data = {'name': 'empty objective'} + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + response = self.client.post(url, data, format='multipart', **extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + data = {'metrics': self.objective_metrics, + 'description': self.objective_description} + response = self.client.post(url, data, format='multipart', **extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_add_objective_no_version(self): + url = reverse('substrapp:objective-list') + + description_content = 'My Super top objective' + metrics_content = 'def metrics():\n\tpass' + + description = get_temporary_text_file(description_content, + 'description.md') + metrics = get_temporary_text_file(metrics_content, 'metrics.py') + + data = { + 'name': 'tough objective', + 'test_data_sample_keys': self.test_data_sample_keys, + 'description': description, + 'metrics': metrics, + } + + response = self.client.post(url, data, format='multipart') + r = response.json() + + self.assertEqual(r, {'detail': 'A version is required.'}) + self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) + + def test_add_objective_wrong_version(self): + url = reverse('substrapp:objective-list') + + description_content = 'My Super top objective' + metrics_content = 'def metrics():\n\tpass' + + description = get_temporary_text_file(description_content, + 'description.md') + metrics = get_temporary_text_file(metrics_content, 'metrics.py') + + data = { + 'name': 'tough objective', + 'test_data_sample_keys': self.test_data_sample_keys, + 'description': description, + 'metrics': metrics, + } + + extra = { + 'HTTP_ACCEPT': 'application/json;version=-1.0', + } + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + + self.assertEqual(r, {'detail': 'Invalid version in "Accept" header.'}) + self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) + + def 
test_get_objective_metrics(self): + objective = Objective.objects.create( + description=self.objective_description, + metrics=self.objective_metrics) + + with mock.patch('substrapp.views.utils.get_owner', return_value='foo'), \ + mock.patch('substrapp.views.utils.get_object_from_ledger') \ + as mget_object_from_ledger: + mget_object_from_ledger.return_value = get_sample_objective_metadata() + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + response = self.client.get( + f'/objective/{objective.pkhash}/metrics/', **extra) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertNotEqual(objective.pkhash, + compute_hash(response.getvalue())) + self.assertEqual(self.objective_metrics_filename, + response.filename) + + def test_get_objective_metrics_no_version(self): + objective = Objective.objects.create( + description=self.objective_description, + metrics=self.objective_metrics) + response = self.client.get(f'/objective/{objective.pkhash}/metrics/') + r = response.json() + self.assertEqual(r, {'detail': 'A version is required.'}) + self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) + + def test_get_objective_metrics_wrong_version(self): + objective = Objective.objects.create( + description=self.objective_description, + metrics=self.objective_metrics) + extra = { + 'HTTP_ACCEPT': 'application/json;version=-1.0', + } + response = self.client.get(f'/objective/{objective.pkhash}/metrics/', + **extra) + r = response.json() + self.assertEqual(r, {'detail': 'Invalid version in "Accept" header.'}) + self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) diff --git a/backend/substrapp/tests/query/tests_query_tuples.py b/backend/substrapp/tests/query/tests_query_tuples.py new file mode 100644 index 000000000..d46298992 --- /dev/null +++ b/backend/substrapp/tests/query/tests_query_tuples.py @@ -0,0 +1,307 @@ +import os +import shutil +import tempfile + +import mock + +from django.urls import reverse +from django.test import override_settings + +from rest_framework import status +from rest_framework.test import APITestCase + +from substrapp.models import Objective +from substrapp.utils import get_hash + +from ..common import get_sample_objective, AuthenticatedClient + +MEDIA_ROOT = tempfile.mkdtemp() + + +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) +@override_settings(LEDGER_SYNC_ENABLED=True) +class TraintupleQueryTests(APITestCase): + client_class = AuthenticatedClient + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.objective_description, self.objective_description_filename, \ + self.objective_metrics, self.objective_metrics_filename = get_sample_objective() + + self.train_data_sample_keys = ['5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b422'] + self.fake_key = '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0a088' + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + def test_add_traintuple_sync_ok(self): + # Add associated objective + description, _, metrics, _ = get_sample_objective() + Objective.objects.create(description=description, + metrics=metrics) + # post data + url = reverse('substrapp:traintuple-list') + + data = { + 'train_data_sample_keys': self.train_data_sample_keys, + 'algo_key': self.fake_key, + 'data_manager_key': self.fake_key, + 'objective_key': self.fake_key, + 'rank': -1, + 'compute_plan_id': self.fake_key, + 'in_models_keys': [self.fake_key]} + 
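# self.fake_key is only a syntactically valid 64-character key: the query/invoke ledger calls are mocked below, so it never has to exist on a real channel +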
extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch('substrapp.serializers.ledger.traintuple.util.invoke_ledger') as minvoke_ledger, \ + mock.patch('substrapp.views.traintuple.query_ledger') as mquery_ledger: + + raw_pkhash = 'traintuple_pkhash'.encode('utf-8').hex() + mquery_ledger.return_value = {'key': raw_pkhash} + minvoke_ledger.return_value = {'pkhash': raw_pkhash} + + response = self.client.post(url, data, format='multipart', **extra) + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + @override_settings(LEDGER_SYNC_ENABLED=False) + @override_settings( + task_eager_propagates=True, + task_always_eager=True, + broker_url='memory://', + backend='memory' + ) + def test_add_traintuple_no_sync_ok(self): + # Add associated objective + description, _, metrics, _ = get_sample_objective() + Objective.objects.create(description=description, + metrics=metrics) + # post data + url = reverse('substrapp:traintuple-list') + + data = { + 'train_data_sample_keys': self.train_data_sample_keys, + 'algo_key': self.fake_key, + 'data_manager_key': self.fake_key, + 'objective_key': self.fake_key, + 'rank': -1, + 'compute_plan_id': self.fake_key, + 'in_models_keys': [self.fake_key]} + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch('substrapp.serializers.ledger.traintuple.util.invoke_ledger') as minvoke_ledger, \ + mock.patch('substrapp.views.traintuple.query_ledger') as mquery_ledger: + + raw_pkhash = 'traintuple_pkhash'.encode('utf-8').hex() + mquery_ledger.return_value = {'key': raw_pkhash} + minvoke_ledger.return_value = None + + response = self.client.post(url, data, format='multipart', **extra) + + self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) + + def test_add_traintuple_ko(self): + url = reverse('substrapp:traintuple-list') + + data = { + 'train_data_sample_keys': self.train_data_sample_keys, + 'model_key': self.fake_key + } + + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + self.assertIn('This field may not be null.', r['algo_key']) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + Objective.objects.create(description=self.objective_description, + metrics=self.objective_metrics) + data = {'objective': get_hash(self.objective_description)} + response = self.client.post(url, data, format='multipart', **extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_add_traintuple_no_version(self): + # Add associated objective + description, _, metrics, _ = get_sample_objective() + Objective.objects.create(description=description, + metrics=metrics) + # post data + url = reverse('substrapp:traintuple-list') + + data = { + 'train_data_sample_keys': self.train_data_sample_keys, + 'datamanager_key': self.fake_key, + 'model_key': self.fake_key, + 'algo_key': self.fake_key} + + response = self.client.post(url, data, format='multipart') + r = response.json() + self.assertEqual(r, {'detail': 'A version is required.'}) + self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) + + def test_add_traintuple_wrong_version(self): + # Add associated objective + description, _, metrics, _ = get_sample_objective() + Objective.objects.create(description=description, + metrics=metrics) + # post data + url = reverse('substrapp:traintuple-list') + + data = { + 'train_data_sample_keys': self.train_data_sample_keys, + 'datamanager_key': 
self.fake_key, + 'model_key': self.fake_key, + 'algo_key': self.fake_key} + extra = { + 'HTTP_ACCEPT': 'application/json;version=-1.0', + } + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + self.assertEqual(r, {'detail': 'Invalid version in "Accept" header.'}) + self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) + + +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER_SYNC_ENABLED=True) +class TesttupleQueryTests(APITestCase): + client_class = AuthenticatedClient + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.objective_description, self.objective_description_filename, \ + self.objective_metrics, self.objective_metrics_filename = get_sample_objective() + + self.test_data_sample_keys = ['5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b422'] + self.fake_key = '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0a088' + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + def test_add_testtuple_sync_ok(self): + # Add associated objective + description, _, metrics, _ = get_sample_objective() + Objective.objects.create(description=description, + metrics=metrics) + # post data + url = reverse('substrapp:testtuple-list') + + data = { + 'test_data_sample_keys': self.test_data_sample_keys, + 'traintuple_key': self.fake_key, + 'data_manager_key': self.fake_key} + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch('substrapp.serializers.ledger.testtuple.util.invoke_ledger') as minvoke_ledger, \ + mock.patch('substrapp.views.testtuple.query_ledger') as mquery_ledger: + + raw_pkhash = 'testtuple_pkhash'.encode('utf-8').hex() + mquery_ledger.return_value = {'key': raw_pkhash} + minvoke_ledger.return_value = {'pkhash': raw_pkhash} + + response = self.client.post(url, data, format='multipart', **extra) + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + @override_settings(LEDGER_SYNC_ENABLED=False) + @override_settings( + task_eager_propagates=True, + task_always_eager=True, + broker_url='memory://', + backend='memory' + ) + def test_add_testtuple_no_sync_ok(self): + # Add associated objective + description, _, metrics, _ = get_sample_objective() + Objective.objects.create(description=description, + metrics=metrics) + # post data + url = reverse('substrapp:testtuple-list') + + data = { + 'test_data_sample_keys': self.test_data_sample_keys, + 'traintuple_key': self.fake_key, + 'data_manager_key': self.fake_key} + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + with mock.patch('substrapp.serializers.ledger.testtuple.util.invoke_ledger') as minvoke_ledger, \ + mock.patch('substrapp.views.testtuple.query_ledger') as mquery_ledger: + + raw_pkhash = 'testtuple_pkhash'.encode('utf-8').hex() + mquery_ledger.return_value = {'key': raw_pkhash} + minvoke_ledger.return_value = None + + response = self.client.post(url, data, format='multipart', **extra) + + self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) + + def test_add_testtuple_ko(self): + url = reverse('substrapp:testtuple-list') + + data = { + 'test_data_sample_keys': self.test_data_sample_keys, + } + + extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0', + } + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + self.assertIn('This field may not be null.', r['traintuple_key']) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def 
test_add_testtuple_no_version(self): + # Add associated objective + description, _, metrics, _ = get_sample_objective() + Objective.objects.create(description=description, + metrics=metrics) + # post data + url = reverse('substrapp:testtuple-list') + + data = { + 'test_data_sample_keys': self.test_data_sample_keys, + 'traintuple_key': self.fake_key, + 'data_manager_key': self.fake_key} + + response = self.client.post(url, data, format='multipart') + r = response.json() + self.assertEqual(r, {'detail': 'A version is required.'}) + self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) + + def test_add_testtuple_wrong_version(self): + # Add associated objective + description, _, metrics, _ = get_sample_objective() + Objective.objects.create(description=description, + metrics=metrics) + # post data + url = reverse('substrapp:testtuple-list') + + data = { + 'test_data_sample_keys': self.test_data_sample_keys, + 'traintuple_key': self.fake_key, + 'data_manager_key': self.fake_key} + + extra = { + 'HTTP_ACCEPT': 'application/json;version=-1.0', + } + + response = self.client.post(url, data, format='multipart', **extra) + r = response.json() + self.assertEqual(r, {'detail': 'Invalid version in "Accept" header.'}) + self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) diff --git a/substrabac/substrapp/tests/tests_exception.py b/backend/substrapp/tests/tests_exception.py similarity index 59% rename from substrabac/substrapp/tests/tests_exception.py rename to backend/substrapp/tests/tests_exception.py index f8ea10b3f..993e2ba26 100644 --- a/substrabac/substrapp/tests/tests_exception.py +++ b/backend/substrapp/tests/tests_exception.py @@ -2,37 +2,17 @@ import json import docker from django.test import TestCase -from substrapp.generate_exceptions_map import exception_tree, find_exception, MODULES -from substrapp.exception_handler import compute_error_code, get_exception_code +from substrapp.tasks.exception_handler import compute_error_code, get_exception_code, generate_exceptions_map class ExceptionTests(TestCase): - def setUp(self): - pass - - def tearDown(self): - pass - def test_exception_map(self): - # Build the exception map from local configuration - exceptions_classes = set() - - # Get exceptions of modules - for errors_module in MODULES: - exceptions_classes.update(find_exception(errors_module)) - - # Get exceptions from python - exception_tree(BaseException, exceptions_classes) - # Build the exception map - exception_map = dict() - for code_exception, exception_name in enumerate(exceptions_classes, start=1): - exception_map[exception_name] = f'{code_exception:04d}' - + exception_map = generate_exceptions_map(append=False) # Exception map reference - EXCEPTION_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../exceptions.json') + EXCEPTION_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../tasks/exceptions.json') reference_exception_map = json.load(open(EXCEPTION_PATH)) self.assertTrue(set(reference_exception_map.keys()).issubset(set(exception_map.keys()))) diff --git a/backend/substrapp/tests/tests_misc.py b/backend/substrapp/tests/tests_misc.py new file mode 100644 index 000000000..cae6b809a --- /dev/null +++ b/backend/substrapp/tests/tests_misc.py @@ -0,0 +1,96 @@ +from django.test import TestCase + +from mock import patch +from substrapp.tasks.utils import get_cpu_sets, get_gpu_sets + +from substrapp.ledger_utils import LedgerNotFound, LedgerBadResponse + +from substrapp.ledger_utils import get_object_from_ledger, 
log_fail_tuple, log_start_tuple, \ + log_success_tuple, query_tuples + + +class MockDevice(): + """A mock device to temporarily suppress output to stdout + Similar to UNIX /dev/null. + """ + + def write(self, s): + pass + + +class MiscTests(TestCase): + """Misc tests""" + + def test_cpu_sets(self): + cpu_count = 16 + for concurrency in range(1, cpu_count + 1, 1): + self.assertEqual(concurrency, + len(get_cpu_sets(cpu_count, concurrency))) + + def test_gpu_sets(self): + gpu_list = ['0', '1'] + for concurrency in range(1, len(gpu_list) + 1, 1): + self.assertEqual(concurrency, + len(get_gpu_sets(gpu_list, concurrency))) + + self.assertFalse(get_gpu_sets([], concurrency)) + + def test_get_object_from_ledger(self): + with patch('substrapp.ledger_utils.query_ledger') as mquery_ledger: + mquery_ledger.side_effect = LedgerNotFound('Not Found') + self.assertRaises(LedgerNotFound, get_object_from_ledger, 'pk', 'fake_query') + + with patch('substrapp.ledger_utils.query_ledger') as mquery_ledger: + mquery_ledger.side_effect = LedgerBadResponse('Bad Response') + self.assertRaises(LedgerBadResponse, get_object_from_ledger, 'pk', 'fake_query') + + with patch('substrapp.ledger_utils.query_ledger') as mquery_ledger: + mquery_ledger.return_value = {'key': 'pk'} + data = get_object_from_ledger('pk', 'good_query') + self.assertEqual(data['key'], 'pk') + + def test_log_fail_tuple(self): + with patch('substrapp.ledger_utils.invoke_ledger') as minvoke_ledger: + minvoke_ledger.return_value = None + log_fail_tuple('traintuple', 'pk', 'error_msg') + + with patch('substrapp.ledger_utils.invoke_ledger') as minvoke_ledger: + minvoke_ledger.return_value = None + log_fail_tuple('testtuple', 'pk', 'error_msg') + + def test_log_start_tuple(self): + with patch('substrapp.ledger_utils.invoke_ledger') as minvoke_ledger: + minvoke_ledger.return_value = None + log_start_tuple('traintuple', 'pk') + + with patch('substrapp.ledger_utils.invoke_ledger') as minvoke_ledger: + minvoke_ledger.return_value = None + log_start_tuple('testtuple', 'pk') + + def test_log_success_tuple(self): + with patch('substrapp.ledger_utils.invoke_ledger') as minvoke_ledger: + minvoke_ledger.return_value = None + res = { + 'end_model_file_hash': 'hash', + 'end_model_file': 'storageAddress', + 'global_perf': '0.99', + 'job_task_log': 'log', + } + log_success_tuple('traintuple', 'pk', res) + + with patch('substrapp.ledger_utils.invoke_ledger') as minvoke_ledger: + minvoke_ledger.return_value = None + res = { + 'global_perf': '0.99', + 'job_task_log': 'log', + } + log_success_tuple('testtuple', 'pk', res) + + def test_query_tuples(self): + with patch('substrapp.ledger_utils.query_ledger') as mquery_ledger: + mquery_ledger.return_value = None + query_tuples('traintuple', 'data_owner') + + with patch('substrapp.ledger_utils.query_ledger') as mquery_ledger: + mquery_ledger.return_value = None + query_tuples('testtuple', 'data_owner') diff --git a/substrabac/substrapp/tests/tests_model.py b/backend/substrapp/tests/tests_model.py similarity index 93% rename from substrabac/substrapp/tests/tests_model.py rename to backend/substrapp/tests/tests_model.py index a8442aac8..5dedfc6aa 100644 --- a/substrabac/substrapp/tests/tests_model.py +++ b/backend/substrapp/tests/tests_model.py @@ -19,10 +19,7 @@ class ModelTests(TestCase): """Model tests""" def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) def test_create_objective(self): description, _, metrics, _ = get_sample_objective() @@ 
-44,7 +41,7 @@ def test_create_datamanager(self): def test_create_data(self): dir_path = os.path.dirname(os.path.realpath(__file__)) -        path = os.path.join(dir_path, '../../fixtures/chunantes/datasamples/train/0024308') +        path = os.path.join(dir_path, '../../../fixtures/chunantes/datasamples/train/0024308') data_sample = DataSample.objects.create(path=path) self.assertEqual(data_sample.pkhash, dirhash(path, 'sha256')) self.assertFalse(data_sample.validated) diff --git a/backend/substrapp/tests/tests_password_validation.py b/backend/substrapp/tests/tests_password_validation.py new file mode 100644 index 000000000..330a2d87a --- /dev/null +++ b/backend/substrapp/tests/tests_password_validation.py @@ -0,0 +1,43 @@ +from django.test import TestCase +from django.core.exceptions import ValidationError +from libs.maximumLengthValidator import MaximumLengthValidator +from libs.zxcvbnValidator import ZxcvbnValidator + + +class PasswordValidationTests(TestCase): + +    def setUp(self): +        self.max_len_validator = MaximumLengthValidator() +        self.complexity_validator = ZxcvbnValidator() + +    def test_password_invalid_length(self): +        password_short = "aaa" +        password_too_long = "a" * 65 + +        # short password OK +        try: +            self.max_len_validator.validate(password_short) +        except Exception: +            self.fail("Password validation should succeed when the password is not too long.") + +        # too long password NOT OK +        self.assertRaisesRegex(ValidationError, +                               "This password is too long. It must contain a maximum of 64 characters.", +                               self.max_len_validator.validate, +                               password_too_long) + +    def test_password_complexity(self): +        password_not_complex = "abc" +        password_complex = "p@$swr0d44" + +        # complex password OK +        try: +            self.complexity_validator.validate(password_complex) +        except Exception: +            self.fail("Password validation should succeed when the password is complex enough.") + +        # easy-to-guess password NOT OK +        self.assertRaisesRegex(ValidationError, +                               "This password is not complex enough.*", +                               self.complexity_validator.validate, +                               password_not_complex) diff --git a/backend/substrapp/tests/tests_tasks.py b/backend/substrapp/tests/tests_tasks.py new file mode 100644 index 000000000..bc4ed86ec --- /dev/null +++ b/backend/substrapp/tests/tests_tasks.py @@ -0,0 +1,707 @@ +import os +import shutil +import mock +import uuid +from unittest.mock import MagicMock + +from django.test import override_settings +from rest_framework import status +from rest_framework.test import APITestCase +from django_celery_results.models import TaskResult + +from substrapp.models import DataSample +from substrapp.ledger_utils import LedgerStatusError +from substrapp.utils import store_datasamples_archive +from substrapp.utils import compute_hash, get_remote_file_content, get_hash, create_directory +from substrapp.tasks.utils import ResourcesManager, compute_docker +from substrapp.tasks.tasks import (build_subtuple_folders, get_algo, get_model, get_models, get_objective, put_opener, +                                   put_model, put_models, put_algo, put_metric, put_data_sample, prepare_task, do_task, +                                   compute_task, remove_subtuple_materials, prepare_materials) + +from .common import (get_sample_algo, get_sample_script, get_sample_zip_data_sample, get_sample_tar_data_sample, +                     get_sample_model) +from .common import FakeObjective, FakeDataManager, FakeModel, FakeRequest +from . 
import assets +from node.models import OutgoingNode + +import zipfile +import docker +MEDIA_ROOT = "/tmp/unittests_tasks/" +# MEDIA_ROOT = tempfile.mkdtemp() + + +# APITestCase +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) +class TasksTests(APITestCase): + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.subtuple_path = MEDIA_ROOT + + self.script, self.script_filename = get_sample_script() + + self.algo, self.algo_filename = get_sample_algo() + self.data_sample, self.data_sample_filename = get_sample_zip_data_sample() + self.data_sample_tar, self.data_sample_tar_filename = get_sample_tar_data_sample() + self.model, self.model_filename = get_sample_model() + + self.ResourcesManager = ResourcesManager() + + @classmethod + def setUpTestData(cls): + cls.outgoing_node = OutgoingNode.objects.create(node_id="external_node_id", secret="s3cr37") + cls.outgoing_node_traintuple = OutgoingNode.objects.create(node_id=assets.traintuple[1]['creator'], + secret="s3cr37") + if assets.traintuple[1]['creator'] != assets.algo[0]['owner']: + cls.outgoing_node_algo = OutgoingNode.objects.create(node_id=assets.algo[0]['owner'], secret="s3cr37") + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + def test_create_directory(self): + directory = './test/' + create_directory(directory) + self.assertTrue(os.path.exists(directory)) + remove_subtuple_materials(directory) + self.assertFalse(os.path.exists(directory)) + + # Remove a second time, it should not raise exception + try: + remove_subtuple_materials(directory) + except Exception: + self.fail('`remove_subtuple_materials` raised Exception unexpectedly!') + + def test_get_remote_file_content(self): + content = str(self.script.read()) + pkhash = compute_hash(content) + remote_file = {'storageAddress': 'localhost', + 'hash': pkhash, + 'owner': 'external_node_id', + } + + with mock.patch('substrapp.utils.get_owner') as get_owner,\ + mock.patch('substrapp.utils.requests.get') as request_get: + get_owner.return_value = 'external_node_id' + request_get.return_value = FakeRequest(content=content, status=status.HTTP_200_OK) + + content_remote = get_remote_file_content(remote_file, 'external_node_id', pkhash) + self.assertEqual(content_remote, content) + + with mock.patch('substrapp.utils.get_owner') as get_owner,\ + mock.patch('substrapp.utils.requests.get') as request_get: + get_owner.return_value = 'external_node_id' + request_get.return_value = FakeRequest(content=content, status=status.HTTP_200_OK) + + with self.assertRaises(Exception): + # contents (by pkhash) are different + get_remote_file_content(remote_file, 'external_node_id', 'fake_pkhash') + + def test_Ressource_Manager(self): + + self.assertTrue(isinstance(self.ResourcesManager.memory_limit_mb(), int)) + + cpu_set, gpu_set = self.ResourcesManager.get_cpu_gpu_sets() + self.assertIn(cpu_set, self.ResourcesManager._ResourcesManager__cpu_sets) + + if gpu_set is not None: + self.assertIn(gpu_set, self.ResourcesManager._ResourcesManager__gpu_sets) + + def test_put_algo_tar(self): + algo_content = self.algo.read() + subtuple_key = get_hash(self.algo) + + subtuple = {'key': subtuple_key, + 'algo': 'testalgo'} + + with mock.patch('substrapp.tasks.tasks.get_hash') as mget_hash: + mget_hash.return_value = subtuple_key + put_algo(os.path.join(self.subtuple_path, f'subtuple/{subtuple["key"]}/'), algo_content) + + self.assertTrue(os.path.exists(os.path.join(self.subtuple_path, 
f'subtuple/{subtuple["key"]}/algo.py'))) + self.assertTrue(os.path.exists(os.path.join(self.subtuple_path, f'subtuple/{subtuple["key"]}/Dockerfile'))) + + def test_put_algo_zip(self): + filename = 'algo.py' + filepath = os.path.join(self.subtuple_path, filename) + with open(filepath, 'w') as f: + f.write('Hello World') + self.assertTrue(os.path.exists(filepath)) + + zipname = 'sample.zip' + zippath = os.path.join(self.subtuple_path, zipname) + with zipfile.ZipFile(zippath, mode='w') as zf: + zf.write(filepath, arcname=filename) + self.assertTrue(os.path.exists(zippath)) + + subtuple_key = 'testkey' + subtuple = {'key': subtuple_key, 'algo': 'testalgo'} + + with mock.patch('substrapp.tasks.tasks.get_hash') as mget_hash: + with open(zippath, 'rb') as content: + mget_hash.return_value = get_hash(zippath) + put_algo(os.path.join(self.subtuple_path, f'subtuple/{subtuple["key"]}/'), content.read()) + + self.assertTrue(os.path.exists(os.path.join(self.subtuple_path, f'subtuple/{subtuple["key"]}/{filename}'))) + + def test_put_metric(self): + filename = 'metrics.py' + filepath = os.path.join(self.subtuple_path, filename) + with open(filepath, 'w') as f: + f.write('Hello World') + self.assertTrue(os.path.exists(filepath)) + + zipname = 'sample.zip' + zippath = os.path.join(self.subtuple_path, zipname) + with zipfile.ZipFile(zippath, mode='w') as zf: + zf.write(filepath, arcname=filename) + self.assertTrue(os.path.exists(zippath)) + + metrics_directory = os.path.join(self.subtuple_path) + create_directory(metrics_directory) + + with mock.patch('substrapp.tasks.tasks.get_hash') as mget_hash: + with open(zippath, 'rb') as content: + mget_hash.return_value = 'hash_value' + put_metric(metrics_directory, content.read()) + + self.assertTrue(os.path.exists(os.path.join(metrics_directory, 'metrics.py'))) + + def test_put_opener(self): + + filepath = os.path.join(self.subtuple_path, self.script_filename) + with open(filepath, 'w') as f: + f.write(self.script.read()) + self.assertTrue(os.path.exists(filepath)) + + opener_hash = get_hash(filepath) + + opener_directory = os.path.join(self.subtuple_path, 'opener') + create_directory(opener_directory) + + with mock.patch('substrapp.models.DataManager.objects.get') as mget: + mget.return_value = FakeDataManager(filepath) + + # test fail + with self.assertRaises(Exception): + put_opener({'dataset': {'openerHash': 'HASH'}}, self.subtuple_path) + + # test work + put_opener({'dataset': {'openerHash': opener_hash}}, self.subtuple_path) + + opener_path = os.path.join(opener_directory, 'opener.py') + self.assertTrue(os.path.exists(opener_path)) + + # test corrupted + + os.remove(opener_path) + shutil.copyfile(filepath, opener_path) + + # Corrupted + with open(opener_path, 'a+') as f: + f.write('corrupted') + + with self.assertRaises(Exception): + put_opener({'dataset': {'openerHash': opener_hash}}, self.subtuple_path) + + def test_put_data_sample_zip(self): + + dir_pkhash, datasamples_path_from_file = store_datasamples_archive(self.data_sample) + + data_sample = DataSample(pkhash=dir_pkhash, path=datasamples_path_from_file) + data_sample.save() + + subtuple = { + 'key': 'bar', + 'dataset': {'keys': [data_sample.pk]} + } + + with mock.patch('substrapp.models.DataSample.objects.get') as mget: + mget.return_value = data_sample + + subtuple_directory = build_subtuple_folders(subtuple) + + put_data_sample(subtuple, subtuple_directory) + + # check folder has been correctly renamed with pk of directory containing uncompressed data sample + self.assertFalse( + 
os.path.exists(os.path.join(MEDIA_ROOT, 'datasamples', 'foo'))) + self.assertTrue( + os.path.exists(os.path.join(MEDIA_ROOT, 'datasamples', dir_pkhash))) + + # check subtuple folder has been created and sym links exists + self.assertTrue(os.path.exists(os.path.join( + MEDIA_ROOT, 'subtuple/bar/data', data_sample.pk))) + self.assertTrue(os.path.islink(os.path.join( + MEDIA_ROOT, 'subtuple/bar/data', data_sample.pk))) + self.assertTrue(os.path.exists(os.path.join( + MEDIA_ROOT, 'subtuple/bar/data', data_sample.pk, 'LABEL_0024900.csv'))) + self.assertTrue(os.path.exists(os.path.join( + MEDIA_ROOT, 'subtuple/bar/data', data_sample.pk, 'IMG_0024900.jpg'))) + + def test_put_data_sample_zip_fail(self): + + data_sample = DataSample(pkhash='foo', path=self.data_sample) + data_sample.save() + + subtuple = { + 'key': 'bar', + 'dataset': {'keys': ['fake_pk']} + } + + subtuple2 = { + 'key': 'bar', + 'dataset': {'keys': [data_sample.pk]} + } + + with mock.patch('substrapp.models.DataSample.objects.get') as mget: + mget.return_value = data_sample + + subtuple_directory = build_subtuple_folders(subtuple) + + with self.assertRaises(Exception): + put_data_sample(subtuple, subtuple_directory) + + with self.assertRaises(Exception): + put_data_sample(subtuple2, '/fake/directory/failure') + + def test_put_data_tar(self): + + dir_pkhash, datasamples_path_from_file = store_datasamples_archive(self.data_sample_tar) + + data_sample = DataSample(pkhash=dir_pkhash, path=datasamples_path_from_file) + data_sample.save() + + subtuple = { + 'key': 'bar', + 'dataset': {'keys': [data_sample.pk]} + } + + with mock.patch('substrapp.models.DataSample.objects.get') as mget: + mget.return_value = data_sample + + subtuple_directory = build_subtuple_folders(subtuple) + + put_data_sample(subtuple, subtuple_directory) + + # check folder has been correctly renamed with pk of directory containing uncompressed data_sample + self.assertFalse(os.path.exists(os.path.join(MEDIA_ROOT, 'datasamples', 'foo'))) + self.assertTrue(os.path.exists(os.path.join(MEDIA_ROOT, 'datasamples', dir_pkhash))) + + # check subtuple folder has been created and sym links exists + self.assertTrue(os.path.exists(os.path.join( + MEDIA_ROOT, 'subtuple/bar/data', data_sample.pk))) + self.assertTrue(os.path.islink(os.path.join( + MEDIA_ROOT, 'subtuple/bar/data', data_sample.pk))) + self.assertTrue(os.path.exists(os.path.join( + MEDIA_ROOT, 'subtuple/bar/data', data_sample.pk, 'LABEL_0024900.csv'))) + self.assertTrue(os.path.exists(os.path.join( + MEDIA_ROOT, 'subtuple/bar/data', data_sample.pk, 'IMG_0024900.jpg'))) + + def test_put_model(self): + + model_content = self.model.read().encode() + + traintupleKey = compute_hash(model_content) + model_hash = compute_hash(model_content, traintupleKey) + model_type = 'model' + subtuple = {'key': model_hash, model_type: {'hash': model_hash, 'traintupleKey': traintupleKey}} + + model_directory = os.path.join(self.subtuple_path, 'model') + create_directory(model_directory) + put_model(subtuple, self.subtuple_path, model_content) + + model_path = os.path.join(model_directory, traintupleKey) + self.assertTrue(os.path.exists(model_path)) + + shutil.copyfile(model_path, model_path + '-local') + + # Corrupted + with open(model_path, 'a+') as f: + f.write('corrupted') + + with mock.patch('substrapp.models.Model.objects.get') as mget: + mget.return_value = FakeModel(model_path + '-local') + with self.assertRaises(Exception): + put_model({'model': {'hash': model_hash, 'traintupleKey': traintupleKey}}, + self.subtuple_path, 
model_content) + + os.remove(model_path) + + with mock.patch('substrapp.models.Model.objects.get') as mget: + mget.return_value = FakeModel(model_path + '-local') + put_model(subtuple, self.subtuple_path, model_content) + self.assertTrue(os.path.exists(model_path)) + + with mock.patch('substrapp.models.Model.objects.get') as mget: + mget.return_value = FakeModel(model_path) + with self.assertRaises(Exception): + put_model({'model': {'hash': 'fail-hash', 'traintupleKey': traintupleKey}}, + self.subtuple_path, model_content) + + with self.assertRaises(Exception): + put_model(subtuple, self.subtuple_path, None) + + def test_put_models(self): + + model_content = self.model.read().encode() + models_content = [model_content, model_content + b', -2.0'] + + traintupleKey = compute_hash(models_content[0]) + model_hash = compute_hash(models_content[0], traintupleKey) + + traintupleKey2 = compute_hash(models_content[1]) + model_hash2 = compute_hash(models_content[1], traintupleKey2) + + model_path = os.path.join(self.subtuple_path, 'model', traintupleKey) + model_path2 = os.path.join(self.subtuple_path, 'model', traintupleKey2) + + model_type = 'inModels' + subtuple = {model_type: [{'hash': model_hash, 'traintupleKey': traintupleKey}, + {'hash': model_hash2, 'traintupleKey': traintupleKey2}]} + + model_directory = os.path.join(self.subtuple_path, 'model/') + + create_directory(model_directory) + put_models(subtuple, self.subtuple_path, models_content) + + self.assertTrue(os.path.exists(model_path)) + self.assertTrue(os.path.exists(model_path2)) + + os.rename(model_path, model_path + '-local') + os.rename(model_path2, model_path2 + '-local') + + with mock.patch('substrapp.models.Model.objects.get') as mget: + mget.side_effect = [FakeModel(model_path + '-local'), FakeModel(model_path2 + '-local')] + put_models(subtuple, self.subtuple_path, models_content) + + self.assertTrue(os.path.exists(model_path)) + self.assertTrue(os.path.exists(model_path2)) + + with mock.patch('substrapp.models.Model.objects.get') as mget: + mget.return_value = FakeModel(model_path) + with self.assertRaises(Exception): + put_models({'inModels': [{'hash': 'hash'}]}, self.subtuple_path, model_content) + + with self.assertRaises(Exception): + put_models({'model': {'hash': 'fail-hash'}}, self.subtuple_path, None) + + def test_get_model(self): + model_content = self.model.read().encode() + traintupleKey = compute_hash(model_content) + model_hash = compute_hash(model_content, traintupleKey) + model_type = 'model' + subtuple = {model_type: {'hash': model_hash, 'traintupleKey': traintupleKey}} + + with mock.patch('substrapp.tasks.utils.get_remote_file_content') as mget_remote_file, \ + mock.patch('substrapp.tasks.utils.get_owner') as mget_owner,\ + mock.patch('substrapp.tasks.tasks.get_object_from_ledger') as mget_object_from_ledger: + mget_remote_file.return_value = model_content + mget_owner.return_value = assets.traintuple[1]['creator'] + mget_object_from_ledger.return_value = assets.traintuple[1] # uses index 1 to have a set value of outModel + model_content = get_model(subtuple) + + self.assertIsNotNone(model_content) + + self.assertIsNone(get_model({})) + + def test_get_models(self): + model_content = self.model.read().encode() + models_content = [model_content, model_content + b', -2.0'] + + traintupleKey = compute_hash(models_content[0]) + model_hash = compute_hash(models_content[0], traintupleKey) + + traintupleKey2 = compute_hash(models_content[1]) + model_hash2 = compute_hash(models_content[1], traintupleKey2) + + model_type 
= 'inModels' +        subtuple = {model_type: [{'hash': model_hash, 'traintupleKey': traintupleKey}, +                                 {'hash': model_hash2, 'traintupleKey': traintupleKey2}]} + +        with mock.patch('substrapp.tasks.utils.get_remote_file_content') as mget_remote_file, \ +                mock.patch('substrapp.tasks.utils.authenticate_worker'),\ +                mock.patch('substrapp.tasks.tasks.get_object_from_ledger'): +            mget_remote_file.side_effect = (models_content[0], models_content[1]) +            models_content_res = get_models(subtuple) + +            self.assertEqual(models_content_res, models_content) + +        self.assertEqual(len(get_models({})), 0) + +    def test_get_algo(self): +        algo_content = self.algo.read() +        algo_hash = get_hash(self.algo) + +        subtuple = { +            'algo': { +                'storageAddress': assets.algo[0]['content']['storageAddress'], +                'owner': assets.algo[0]['owner'], +                'hash': algo_hash +            } +        } + +        with mock.patch('substrapp.tasks.utils.get_remote_file_content') as mget_remote_file,\ +                mock.patch('substrapp.tasks.utils.get_owner') as get_owner,\ +                mock.patch('substrapp.tasks.tasks.get_object_from_ledger') as get_object_from_ledger: +            mget_remote_file.return_value = algo_content +            get_owner.return_value = 'external_node_id' +            get_object_from_ledger.return_value = assets.algo[0] + +            data = get_algo(subtuple) +            self.assertEqual(algo_content, data) + +    def test_get_objective(self): +        metrics_content = self.script.read().encode('utf-8') +        objective_hash = get_hash(self.script) + +        with mock.patch('substrapp.models.Objective.objects.get') as mget: + +            mget.return_value = FakeObjective() + +            objective = get_objective({'objective': {'hash': objective_hash, +                                                     'metrics': ''}}) +            self.assertTrue(isinstance(objective, bytes)) +            self.assertEqual(objective, b'foo') + +        with mock.patch('substrapp.tasks.utils.get_remote_file_content') as mget_remote_file, \ +                mock.patch('substrapp.tasks.tasks.get_object_from_ledger'), \ +                mock.patch('substrapp.tasks.utils.authenticate_worker'),\ +                mock.patch('substrapp.models.Objective.objects.update_or_create') as mupdate_or_create: + +            mget.return_value = FakeObjective() +            mget_remote_file.return_value = metrics_content +            mupdate_or_create.return_value = FakeObjective(), True + +            objective = get_objective({'objective': {'hash': objective_hash, +                                                     'metrics': ''}}) +            self.assertTrue(isinstance(objective, bytes)) +            self.assertEqual(objective, b'foo') + +    def test_compute_docker(self): +        cpu_set, gpu_set = None, None +        client = docker.from_env() + +        dockerfile_path = os.path.join(self.subtuple_path, 'Dockerfile') +        with open(dockerfile_path, 'w') as f: +            f.write('FROM library/hello-world') + +        hash_docker = uuid.uuid4().hex +        compute_docker(client, self.ResourcesManager, +                       self.subtuple_path, 'test_compute_docker_' + hash_docker, +                       'test_compute_docker_name_' + hash_docker, None, None) + +        self.assertIsNone(cpu_set) +        self.assertIsNone(gpu_set) + +    def test_build_subtuple_folders(self): +        with mock.patch('substrapp.tasks.tasks.getattr') as getattr: +            getattr.return_value = self.subtuple_path + +            subtuple_key = 'test1234' +            subtuple = {'key': subtuple_key} +            subtuple_directory = build_subtuple_folders(subtuple) + +            self.assertTrue(os.path.exists(subtuple_directory)) +            self.assertEqual(os.path.join(self.subtuple_path, f'subtuple/{subtuple["key"]}'), subtuple_directory) + +            for root, dirs, files in os.walk(subtuple_directory): +                nb_subfolders = len(dirs) +                break  # only check the first level of the subtuple directory +            self.assertEqual(5, nb_subfolders) + +    @override_settings( +        task_eager_propagates=True, +        task_always_eager=True, +        broker_url='memory://', +        backend='memory' +    ) +    def 
test_prepare_tasks(self): + + class FakeSettings(object): + def __init__(self): + self.LEDGER = {'signcert': 'signcert', + 'org': 'owkin', + 'peer': 'peer'} + + self.MEDIA_ROOT = MEDIA_ROOT + + subtuple = [{'key': 'subtuple_test', 'computePlanID': 'flkey'}] + + with mock.patch('substrapp.tasks.tasks.settings') as msettings, \ + mock.patch.object(TaskResult.objects, 'filter') as mtaskresult, \ + mock.patch('substrapp.tasks.tasks.get_hash') as mget_hash, \ + mock.patch('substrapp.tasks.tasks.query_tuples') as mquery_tuples, \ + mock.patch('substrapp.tasks.tasks.get_objective') as mget_objective, \ + mock.patch('substrapp.tasks.tasks.get_algo') as mget_algo, \ + mock.patch('substrapp.tasks.tasks.get_model') as mget_model, \ + mock.patch('substrapp.tasks.tasks.build_subtuple_folders') as mbuild_subtuple_folders, \ + mock.patch('substrapp.tasks.tasks.put_opener') as mput_opener, \ + mock.patch('substrapp.tasks.tasks.put_data_sample') as mput_data_sample, \ + mock.patch('substrapp.tasks.tasks.put_metric') as mput_metric, \ + mock.patch('substrapp.tasks.tasks.put_algo') as mput_algo, \ + mock.patch('substrapp.tasks.tasks.json.loads') as mjson_loads, \ + mock.patch('substrapp.tasks.tasks.AsyncResult') as masyncres, \ + mock.patch('substrapp.tasks.tasks.put_model') as mput_model, \ + mock.patch('substrapp.tasks.tasks.get_owner') as get_owner: + + msettings.return_value = FakeSettings() + mget_hash.return_value = 'owkinhash' + mquery_tuples.return_value = subtuple + mget_objective.return_value = 'objective' + mget_algo.return_value = 'algo', 'algo_hash' + mget_model.return_value = 'model', 'model_hash' + mbuild_subtuple_folders.return_value = MEDIA_ROOT + mput_opener.return_value = 'opener' + mput_data_sample.return_value = 'data' + mput_metric.return_value = 'metric' + mput_algo.return_value = 'algo' + mput_model.return_value = 'model' + get_owner.return_value = 'foo' + + masyncres.return_value.state = 'PENDING' + + mock_filter = MagicMock() + mock_filter.count.return_value = 1 + mtaskresult.return_value = mock_filter + + mjson_loads.return_value = {'worker': 'worker'} + + with mock.patch('substrapp.tasks.tasks.log_start_tuple') as mlog_start_tuple: + mlog_start_tuple.side_effect = LedgerStatusError('Bad Response') + prepare_task('traintuple') + + with mock.patch('substrapp.tasks.tasks.log_start_tuple') as mlog_start_tuple, \ + mock.patch('substrapp.tasks.tasks.compute_task.apply_async') as mapply_async: + mlog_start_tuple.return_value = 'data', 201 + mapply_async.return_value = 'do_task' + prepare_task('traintuple') + + def test_do_task(self): + + class FakeSettings(object): + def __init__(self): + self.LEDGER = {'signcert': 'signcert', + 'org': 'owkin', + 'peer': 'peer'} + + self.MEDIA_ROOT = MEDIA_ROOT + + subtuple_key = 'test_owkin' + subtuple = {'key': subtuple_key, 'inModels': None} + subtuple_directory = build_subtuple_folders(subtuple) + + with mock.patch('substrapp.tasks.tasks.settings') as msettings, \ + mock.patch('substrapp.tasks.tasks.getattr') as mgetattr: + msettings.return_value = FakeSettings() + mgetattr.return_value = self.subtuple_path + + for name in ['opener', 'metrics']: + with open(os.path.join(subtuple_directory, f'{name}/{name}.py'), 'w') as f: + f.write('Hello World') + + perf = 0.3141592 + with open(os.path.join(subtuple_directory, 'pred/perf.json'), 'w') as f: + f.write(f'{{"all": {perf}}}') + + with open(os.path.join(subtuple_directory, 'model/model'), 'w') as f: + f.write("MODEL") + + with mock.patch('substrapp.tasks.tasks.compute_docker') as mcompute_docker: + 
mcompute_docker.return_value = 'DONE' + do_task(subtuple, 'traintuple') + + def test_compute_task(self): + + class FakeSettings(object): + def __init__(self): + self.LEDGER = {'signcert': 'signcert', + 'org': 'owkin', + 'peer': 'peer'} + + self.MEDIA_ROOT = MEDIA_ROOT + + subtuple_key = 'test_owkin' + subtuple = {'key': subtuple_key, 'inModels': None} + subtuple_directory = build_subtuple_folders(subtuple) + + with mock.patch('substrapp.tasks.tasks.settings') as msettings, \ + mock.patch('substrapp.tasks.tasks.getattr') as mgetattr, \ + mock.patch('substrapp.tasks.tasks.log_start_tuple') as mlog_start_tuple: + msettings.return_value = FakeSettings() + mgetattr.return_value = self.subtuple_path + mlog_start_tuple.return_value = 'data', 200 + + for name in ['opener', 'metrics']: + with open(os.path.join(subtuple_directory, f'{name}/{name}.py'), 'w') as f: + f.write('Hello World') + + perf = 0.3141592 + with open(os.path.join(subtuple_directory, 'pred/perf.json'), 'w') as f: + f.write(f'{{"all": {perf}}}') + + with open(os.path.join(subtuple_directory, 'model/model'), 'w') as f: + f.write("MODEL") + + with mock.patch('substrapp.tasks.tasks.compute_docker') as mcompute_docker, \ + mock.patch('substrapp.tasks.tasks.do_task') as mdo_task,\ + mock.patch('substrapp.tasks.tasks.prepare_materials') as mprepare_materials, \ + mock.patch('substrapp.tasks.tasks.log_success_tuple') as mlog_success_tuple: + + mcompute_docker.return_value = 'DONE' + mprepare_materials.return_value = 'DONE' + mdo_task.return_value = 'DONE' + + mlog_success_tuple.return_value = 'data', 201 + compute_task('traintuple', subtuple, None) + + mlog_success_tuple.return_value = 'data', 404 + compute_task('traintuple', subtuple, None) + + with mock.patch('substrapp.tasks.tasks.log_fail_tuple') as mlog_fail_tuple: + mdo_task.side_effect = Exception("Test") + mlog_fail_tuple.return_value = 'data', 404 + compute_task('traintuple', subtuple, None) + + def test_prepare_materials(self): + + class FakeSettings(object): + def __init__(self): + self.LEDGER = {'signcert': 'signcert', + 'org': 'owkin', + 'peer': 'peer'} + + self.MEDIA_ROOT = MEDIA_ROOT + + subtuple = [{'key': 'subtuple_test', 'computePlanID': 'flkey'}] + + with mock.patch('substrapp.tasks.tasks.settings') as msettings, \ + mock.patch('substrapp.tasks.tasks.get_hash') as mget_hash, \ + mock.patch('substrapp.tasks.tasks.query_tuples') as mquery_tuples, \ + mock.patch('substrapp.tasks.tasks.get_objective') as mget_objective, \ + mock.patch('substrapp.tasks.tasks.get_algo') as mget_algo, \ + mock.patch('substrapp.tasks.tasks.get_model') as mget_model, \ + mock.patch('substrapp.tasks.tasks.build_subtuple_folders') as mbuild_subtuple_folders, \ + mock.patch('substrapp.tasks.tasks.put_opener') as mput_opener, \ + mock.patch('substrapp.tasks.tasks.put_data_sample') as mput_data_sample, \ + mock.patch('substrapp.tasks.tasks.put_metric') as mput_metric, \ + mock.patch('substrapp.tasks.tasks.put_algo') as mput_algo, \ + mock.patch('substrapp.tasks.tasks.put_model') as mput_model: + + msettings.return_value = FakeSettings() + mget_hash.return_value = 'owkinhash' + mquery_tuples.return_value = subtuple, 200 + mget_objective.return_value = 'objective' + mget_algo.return_value = 'algo', 'algo_hash' + mget_model.return_value = 'model', 'model_hash' + mbuild_subtuple_folders.return_value = MEDIA_ROOT + mput_opener.return_value = 'opener' + mput_data_sample.return_value = 'data' + mput_metric.return_value = 'metric' + mput_algo.return_value = 'algo' + mput_model.return_value = 'model' + + 
prepare_materials(subtuple[0], 'traintuple') + prepare_materials(subtuple[0], 'testtuple') diff --git a/backend/substrapp/tests/views/__init__.py b/backend/substrapp/tests/views/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/substrapp/tests/views/test_views_authentication.py b/backend/substrapp/tests/views/test_views_authentication.py new file mode 100644 index 000000000..ae6d475fc --- /dev/null +++ b/backend/substrapp/tests/views/test_views_authentication.py @@ -0,0 +1,91 @@ +import mock +from django.urls import reverse +import os +import shutil +from rest_framework import status +from rest_framework.test import APITestCase +from node.models import IncomingNode, OutgoingNode +from substrapp.models import Algo + +from ..common import generate_basic_auth_header, get_sample_algo_metadata, get_sample_algo, get_description_algo +from django.test import override_settings + +MEDIA_ROOT = "/tmp/unittests_views/" + + +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +class AuthenticationTests(APITestCase): + def setUp(self): + + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0' + } + + # create algo instance which file download is protected + self.algo_file, self.algo_filename = get_sample_algo() + self.algo_description_file, self.algo_description_filename = get_description_algo() + self.algo = Algo.objects.create(file=self.algo_file, description=self.algo_description_file) + self.algo_url = reverse('substrapp:algo-file', kwargs={'pk': self.algo.pk}) + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + @classmethod + def setUpTestData(cls): + cls.incoming_node = IncomingNode.objects.create(node_id="external_node_id", secret="s3cr37") + cls.outgoing_node = OutgoingNode.objects.create(node_id="external_node_id", secret="s3cr37") + + def test_authentication_fail(self): + response = self.client.get(self.algo_url, **self.extra) + + self.assertEqual(status.HTTP_401_UNAUTHORIZED, response.status_code) + + def test_authentication_internal(self): + authorization_header = generate_basic_auth_header(self.outgoing_node.node_id, self.outgoing_node.secret) + + self.client.credentials(HTTP_AUTHORIZATION=authorization_header) + + with mock.patch('substrapp.views.utils.get_owner', return_value='foo'), \ + mock.patch('substrapp.views.utils.get_object_from_ledger') \ + as mget_object_from_ledger: + mget_object_from_ledger.return_value = get_sample_algo_metadata() + response = self.client.get(self.algo_url, **self.extra) + + self.assertEqual(status.HTTP_200_OK, response.status_code) + + def test_authentication_with_bad_settings_credentials_fail(self): + authorization_header = generate_basic_auth_header('unauthorized_username', 'unauthorized_password') + + self.client.credentials(HTTP_AUTHORIZATION=authorization_header) + response = self.client.get(self.algo_url, **self.extra) + + self.assertEqual(status.HTTP_401_UNAUTHORIZED, response.status_code) + + def test_authentication_with_node(self): + authorization_header = generate_basic_auth_header('external_node_id', 's3cr37') + + self.client.credentials(HTTP_AUTHORIZATION=authorization_header) + + with mock.patch('substrapp.views.utils.get_owner', return_value='foo'), \ + mock.patch('substrapp.views.utils.get_object_from_ledger') \ + as mget_object_from_ledger: + mget_object_from_ledger.return_value = get_sample_algo_metadata() + response = self.client.get(self.algo_url, **self.extra) + + self.assertEqual(status.HTTP_200_OK, 
response.status_code) + + def test_authentication_with_node_fail(self): + bad_authorization_headers = [ + generate_basic_auth_header('external_node_id', 'bad_s3cr37'), + generate_basic_auth_header('bad_external_node_id', 's3cr37'), + generate_basic_auth_header('bad_external_node_id', 'bad_s3cr37'), + ] + + for header in bad_authorization_headers: + self.client.credentials(HTTP_AUTHORIZATION=header) + response = self.client.get(self.algo_url, **self.extra) + + self.assertEqual(status.HTTP_401_UNAUTHORIZED, response.status_code) diff --git a/backend/substrapp/tests/views/tests_utils.py b/backend/substrapp/tests/views/tests_utils.py new file mode 100644 index 000000000..b4d265641 --- /dev/null +++ b/backend/substrapp/tests/views/tests_utils.py @@ -0,0 +1,131 @@ +import functools +import os +import tempfile + +import mock +import requests +from requests.auth import HTTPBasicAuth +from rest_framework import status +from rest_framework.test import APITestCase + +from substrapp.views.utils import PermissionMixin + + +class MockRequest: + user = None + + +def with_permission_mixin(remote, same_file_property, has_access): + def inner(f): + @functools.wraps(f) + def wrapper(self): + ledger_value = { + 'owner': 'owner-foo', + 'file_property' if same_file_property else 'ledger_file_property': { + 'storageAddress': 'foo' + } + } + with mock.patch('substrapp.views.utils.get_object_from_ledger', + return_value=ledger_value), \ + tempfile.NamedTemporaryFile() as tmp_file, \ + mock.patch('substrapp.views.utils.get_owner', + return_value='not-owner-foo' if remote else 'owner-foo'): + tmp_file_content = b'foo bar' + tmp_file.write(tmp_file_content) + tmp_file.seek(0) + + class TestFieldFile: + path = tmp_file.name + + class TestModel: + file_property = TestFieldFile() + + permission_mixin = PermissionMixin() + permission_mixin.get_object = mock.MagicMock(return_value=TestModel()) + permission_mixin._has_access = mock.MagicMock(return_value=has_access) + permission_mixin.lookup_url_kwarg = 'foo' + permission_mixin.kwargs = {'foo': 'bar'} + permission_mixin.ledger_query_call = 'foo' + + kwargs = { + 'tmp_file': tmp_file, + 'content': tmp_file_content, + 'filename': os.path.basename(tmp_file.name) + } + + f(self, permission_mixin, **kwargs) + return wrapper + return inner + + +def with_requests_mock(allowed): + def inner(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + tmp_file = kwargs['tmp_file'] + filename = kwargs['filename'] + + requests_response = requests.Response() + if allowed: + requests_response.raw = tmp_file + requests_response.headers['Content-Disposition'] = f'attachment; filename="{filename}"' + requests_response.status_code = status.HTTP_200_OK + else: + requests_response._content = b'{"message": "nope"}' + requests_response.status_code = status.HTTP_401_UNAUTHORIZED + + kwargs['requests_response'] = requests_response + + with mock.patch('substrapp.views.utils.authenticate_outgoing_request', + return_value=HTTPBasicAuth('foo', 'bar')), \ + mock.patch('substrapp.utils.requests.get', return_value=requests_response): + f(*args, **kwargs) + return wrapper + return inner + + +class PermissionMixinDownloadFileTests(APITestCase): + @with_permission_mixin(remote=False, same_file_property=False, has_access=True) + def test_download_file_local_allowed(self, permission_mixin, content, filename, **kwargs): + res = permission_mixin.download_file(MockRequest(), + 'file_property', + 'ledger_file_property') + res_content = b''.join(list(res.streaming_content)) + self.assertEqual(res_content, 
content) + self.assertEqual(res['Content-Disposition'], f'attachment; filename="{filename}"') + self.assertTrue(permission_mixin.get_object.called) + + @with_permission_mixin(remote=False, same_file_property=True, has_access=False) + def test_download_file_local_denied(self, permission_mixin, **kwargs): + res = permission_mixin.download_file(MockRequest(), 'file_property') + self.assertEqual(res.status_code, status.HTTP_403_FORBIDDEN) + + @with_permission_mixin(remote=True, same_file_property=False, has_access=True) + @with_requests_mock(allowed=True) + def test_download_file_remote_allowed(self, permission_mixin, content, filename, **kwargs): + res = permission_mixin.download_file(MockRequest(), + 'file_property', + 'ledger_file_property') + res_content = b''.join(list(res.streaming_content)) + self.assertEqual(res_content, content) + self.assertEqual(res['Content-Disposition'], f'attachment; filename="{filename}"') + self.assertFalse(permission_mixin.get_object.called) + + @with_permission_mixin(remote=True, same_file_property=False, has_access=True) + @with_requests_mock(allowed=False) + def test_download_file_remote_denied(self, permission_mixin, **kwargs): + res = permission_mixin.download_file(MockRequest(), + 'file_property', + 'ledger_file_property') + self.assertEqual(res.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertFalse(permission_mixin.get_object.called) + + @with_permission_mixin(remote=True, same_file_property=True, has_access=True) + @with_requests_mock(allowed=True) + def test_download_file_remote_same_file_property(self, permission_mixin, content, filename, + **kwargs): + res = permission_mixin.download_file(MockRequest(), 'file_property') + res_content = b''.join(list(res.streaming_content)) + self.assertEqual(res_content, content) + self.assertEqual(res['Content-Disposition'], f'attachment; filename="{filename}"') + self.assertFalse(permission_mixin.get_object.called) diff --git a/backend/substrapp/tests/views/tests_views.py b/backend/substrapp/tests/views/tests_views.py new file mode 100644 index 000000000..ce1d624e1 --- /dev/null +++ b/backend/substrapp/tests/views/tests_views.py @@ -0,0 +1,28 @@ +import mock + +from rest_framework.test import APITestCase + +from substrapp.views.datasample import path_leaf +from substrapp.ledger_utils import get_object_from_ledger + + +from ..assets import objective +from ..common import AuthenticatedClient + +MEDIA_ROOT = "/tmp/unittests_views/" + + +class ViewTests(APITestCase): + client_class = AuthenticatedClient + + def test_data_sample_path_view(self): + self.assertEqual('tutu', path_leaf('/toto/tata/tutu')) + self.assertEqual('toto', path_leaf('/toto/')) + + def test_utils_get_object_from_ledger(self): + + with mock.patch('substrapp.ledger_utils.query_ledger') as mquery_ledger: + mquery_ledger.return_value = objective + data = get_object_from_ledger('', 'queryObjective') + + self.assertEqual(data, objective) diff --git a/backend/substrapp/tests/views/tests_views_algo.py b/backend/substrapp/tests/views/tests_views_algo.py new file mode 100644 index 000000000..adaa9abe6 --- /dev/null +++ b/backend/substrapp/tests/views/tests_views_algo.py @@ -0,0 +1,285 @@ +import copy +import os +import shutil +import logging + +import mock +import urllib.parse + +from django.urls import reverse +from django.test import override_settings + +from rest_framework import status +from rest_framework.test import APITestCase + + +from substrapp.serializers import LedgerAlgoSerializer + +from substrapp.ledger_utils import LedgerError + +from 
substrapp.utils import get_hash + +from ..common import get_sample_algo, AuthenticatedClient +from ..assets import objective, datamanager, algo, traintuple, model + +MEDIA_ROOT = "/tmp/unittests_views/" + + +# APITestCase +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) +@override_settings(LEDGER_SYNC_ENABLED=True) +class AlgoViewTests(APITestCase): + client_class = AuthenticatedClient + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.algo, self.algo_filename = get_sample_algo() + + self.extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0' + } + self.logger = logging.getLogger('django.request') + self.previous_level = self.logger.getEffectiveLevel() + self.logger.setLevel(logging.ERROR) + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + self.logger.setLevel(self.previous_level) + + def test_algo_list_empty(self): + url = reverse('substrapp:algo-list') + with mock.patch('substrapp.views.algo.query_ledger') as mquery_ledger: + mquery_ledger.return_value = [] + + response = self.client.get(url, **self.extra) + r = response.json() + self.assertEqual(r, [[]]) + + def test_algo_list_success(self): + url = reverse('substrapp:algo-list') + with mock.patch('substrapp.views.algo.query_ledger') as mquery_ledger: + mquery_ledger.return_value = algo + + response = self.client.get(url, **self.extra) + r = response.json() + self.assertEqual(r, [algo]) + + def test_algo_list_filter_fail(self): + url = reverse('substrapp:algo-list') + with mock.patch('substrapp.views.algo.query_ledger') as mquery_ledger: + mquery_ledger.return_value = algo + + search_params = '?search=algERRORo' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertIn('Malformed search filters', r['message']) + + def test_algo_list_filter_name(self): + url = reverse('substrapp:algo-list') + with mock.patch('substrapp.views.algo.query_ledger') as mquery_ledger: + mquery_ledger.return_value = algo + + search_params = '?search=algo%253Aname%253ALogistic%2520regression' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), 1) + + def test_algo_list_filter_dual(self): + url = reverse('substrapp:algo-list') + with mock.patch('substrapp.views.algo.query_ledger') as mquery_ledger: + mquery_ledger.return_value = algo + + search_params = f'?search=algo%253Aname%253A{urllib.parse.quote(algo[2]["name"])}' + search_params += f'%2Calgo%253Aowner%253A{algo[2]["owner"]}' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), 1) + + def test_algo_list_filter_datamanager_fail(self): + url = reverse('substrapp:algo-list') + with mock.patch('substrapp.views.algo.query_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.filters_utils.query_ledger') as mquery_ledger2: + mquery_ledger.return_value = algo + mquery_ledger2.return_value = datamanager + + search_params = '?search=dataset%253Aname%253ASimplified%2520ISIC%25202018' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertIn('Malformed search filters', r['message']) + + def test_algo_list_filter_objective_fail(self): + url = reverse('substrapp:algo-list') + with mock.patch('substrapp.views.algo.query_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.filters_utils.query_ledger') as mquery_ledger2: + mquery_ledger.return_value = algo + 
mquery_ledger2.return_value = objective + + search_params = '?search=objective%253Aname%253ASkin%2520Lesion%2520Classification%2520Objective' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertIn('Malformed search filters', r['message']) + + def test_algo_list_filter_model(self): + url = reverse('substrapp:algo-list') + with mock.patch('substrapp.views.algo.query_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.filters_utils.query_ledger') as mquery_ledger2: + mquery_ledger.return_value = algo + mquery_ledger2.return_value = traintuple + + pkhash = model[1]['traintuple']['outModel']['hash'] + search_params = f'?search=model%253Ahash%253A{pkhash}' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), 1) + + def test_algo_retrieve(self): + dir_path = os.path.dirname(os.path.realpath(__file__)) + algo_hash = get_hash(os.path.join(dir_path, '../../../../fixtures/chunantes/algos/algo4/algo.tar.gz')) + url = reverse('substrapp:algo-list') + algo_response = [a for a in algo if a['key'] == algo_hash][0] + with mock.patch('substrapp.views.algo.get_object_from_ledger') as mget_object_from_ledger, \ + mock.patch('substrapp.views.algo.get_remote_asset') as get_remote_asset: + + with open(os.path.join(dir_path, + '../../../../fixtures/chunantes/algos/algo4/description.md'), 'rb') as f: + content = f.read() + mget_object_from_ledger.return_value = algo_response + get_remote_asset.return_value = content + + search_params = f'{algo_hash}/' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(r, algo_response) + + def test_algo_retrieve_fail(self): + + dir_path = os.path.dirname(os.path.realpath(__file__)) + url = reverse('substrapp:algo-list') + + # PK hash < 64 chars + search_params = '42303efa663015e729159833a12ffb510ff/' + response = self.client.get(url + search_params, **self.extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + # PK hash not hexa + search_params = 'X' * 64 + '/' + response = self.client.get(url + search_params, **self.extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + with mock.patch('substrapp.views.algo.get_object_from_ledger') as mget_object_from_ledger: + mget_object_from_ledger.side_effect = LedgerError('TEST') + + file_hash = get_hash(os.path.join(dir_path, + "../../../../fixtures/owkin/objectives/objective0/description.md")) + search_params = f'{file_hash}/' + response = self.client.get(url + search_params, **self.extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_algo_create(self): + url = reverse('substrapp:algo-list') + + dir_path = os.path.dirname(os.path.realpath(__file__)) + + algo_path = os.path.join(dir_path, '../../../../fixtures/chunantes/algos/algo3/algo.tar.gz') + description_path = os.path.join(dir_path, '../../../../fixtures/chunantes/algos/algo3/description.md') + + pkhash = get_hash(algo_path) + + data = {'name': 'Logistic regression', + 'file': open(algo_path, 'rb'), + 'description': open(description_path, 'rb'), + 'objective_key': get_hash(os.path.join( + dir_path, '../../../../fixtures/chunantes/objectives/objective0/description.md')), + 'permissions_public': True, + 'permissions_authorized_ids': []} + + with mock.patch.object(LedgerAlgoSerializer, 'create') as mcreate: + + mcreate.return_value = {} + + response = self.client.post(url, data=data, format='multipart', **self.extra) + + 
self.assertEqual(response.data['pkhash'], pkhash) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + data['description'].close() + data['file'].close() + + def test_algo_list_storage_addresses_update(self): + url = reverse('substrapp:algo-list') + with mock.patch('substrapp.views.algo.query_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.objective.get_remote_asset') as mget_remote_asset: + + # mock content + mget_remote_asset.return_value = b'dummy binary content' + ledger_algos = copy.deepcopy(algo) + for ledger_algo in ledger_algos: + for field in ('description', 'content'): + ledger_algo[field]['storageAddress'] = \ + ledger_algo[field]['storageAddress'].replace('http://testserver', + 'http://remotetestserver') + mquery_ledger.return_value = ledger_algos + + # actual test + res = self.client.get(url, **self.extra) + res_algos = res.data[0] + self.assertEqual(len(res_algos), len(algo)) + for i, res_algo in enumerate(res_algos): + for field in ('description', 'content'): + self.assertEqual(res_algo[field]['storageAddress'], + algo[i][field]['storageAddress']) + + def test_algo_retrieve_storage_addresses_update_with_cache(self): + url = reverse('substrapp:algo-detail', args=[algo[0]['key']]) + with mock.patch('substrapp.views.algo.get_object_from_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.algo.node_has_process_permission', + return_value=True), \ + mock.patch('substrapp.views.algo.get_remote_asset') as mget_remote_asset: + + # mock content + mget_remote_asset.return_value = b'dummy binary content' + ledger_algo = copy.deepcopy(algo[0]) + for field in ('description', 'content'): + ledger_algo[field]['storageAddress'] = \ + ledger_algo[field]['storageAddress'].replace('http://testserver', + 'http://remotetestserver') + mquery_ledger.return_value = ledger_algo + + # actual test + res = self.client.get(url, **self.extra) + for field in ('description', 'content'): + self.assertEqual(res.data[field]['storageAddress'], + algo[0][field]['storageAddress']) + + def test_algo_retrieve_storage_addresses_update_without_cache(self): + url = reverse('substrapp:algo-detail', args=[algo[0]['key']]) + with mock.patch('substrapp.views.algo.get_object_from_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.algo.node_has_process_permission', + return_value=False), \ + mock.patch('substrapp.views.algo.get_remote_asset') as mget_remote_asset: + + # mock content + mget_remote_asset.return_value = b'dummy binary content' + ledger_algo = copy.deepcopy(algo[0]) + for field in ('description', 'content'): + ledger_algo[field]['storageAddress'] = \ + ledger_algo[field]['storageAddress'].replace('http://testserver', + 'http://remotetestserver') + mquery_ledger.return_value = ledger_algo + + # actual test + res = self.client.get(url, **self.extra) + for field in ('description', 'content'): + self.assertEqual(res.data[field]['storageAddress'], + algo[0][field]['storageAddress']) diff --git a/backend/substrapp/tests/views/tests_views_computeplan.py b/backend/substrapp/tests/views/tests_views_computeplan.py new file mode 100644 index 000000000..352351683 --- /dev/null +++ b/backend/substrapp/tests/views/tests_views_computeplan.py @@ -0,0 +1,63 @@ +import os +import shutil + +import mock + +from django.urls import reverse +from django.test import override_settings + +from rest_framework import status +from rest_framework.test import APITestCase + +from substrapp.serializers import LedgerComputePlanSerializer +from ..common import AuthenticatedClient + +MEDIA_ROOT = 
"/tmp/unittests_views/" + + +# APITestCase +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) +@override_settings(LEDGER_SYNC_ENABLED=True) +class ComputePlanViewTests(APITestCase): + client_class = AuthenticatedClient + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0' + } + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + def test_create(self): + url = reverse('substrapp:compute_plan-list') + + dummy_key = 'x' * 64 + + data = { + 'algo_key': dummy_key, + 'objective_key': dummy_key, + 'traintuples': [{ + 'data_manager_key': dummy_key, + 'train_data_sample_keys': [dummy_key], + 'traintuple_id': dummy_key, + }], + 'testtuples': [{ + 'traintuple_id': dummy_key, + 'data_manager_key': dummy_key, + }], + } + + with mock.patch.object(LedgerComputePlanSerializer, 'create') as mcreate: + with mock.patch('substrapp.views.computeplan.query_ledger') as mquery: + mcreate.return_value = {} + mquery.return_value = {} + + response = self.client.post(url, data=data, format='json', **self.extra) + + self.assertEqual(response.json(), {}) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) diff --git a/backend/substrapp/tests/views/tests_views_datamanager.py b/backend/substrapp/tests/views/tests_views_datamanager.py new file mode 100644 index 000000000..b55d78de6 --- /dev/null +++ b/backend/substrapp/tests/views/tests_views_datamanager.py @@ -0,0 +1,232 @@ +import copy +import os +import shutil +import logging + +import mock + +from django.urls import reverse +from django.test import override_settings + +from rest_framework import status +from rest_framework.test import APITestCase + +from substrapp.ledger_utils import LedgerError +from substrapp.utils import get_hash + + +from ..common import get_sample_datamanager, AuthenticatedClient +from ..assets import objective, datamanager, traintuple, model + +MEDIA_ROOT = "/tmp/unittests_views/" + + +# APITestCase +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) +@override_settings(LEDGER_SYNC_ENABLED=True) +class DataManagerViewTests(APITestCase): + client_class = AuthenticatedClient + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.data_description, self.data_description_filename, \ + self.data_data_opener, self.data_opener_filename = get_sample_datamanager() + + self.extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0' + } + + self.logger = logging.getLogger('django.request') + self.previous_level = self.logger.getEffectiveLevel() + self.logger.setLevel(logging.ERROR) + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + self.logger.setLevel(self.previous_level) + + def test_datamanager_list_empty(self): + url = reverse('substrapp:data_manager-list') + with mock.patch('substrapp.views.datamanager.query_ledger') as mquery_ledger: + mquery_ledger.return_value = [] + + response = self.client.get(url, **self.extra) + r = response.json() + self.assertEqual(r, [[]]) + + def test_datamanager_list_success(self): + url = reverse('substrapp:data_manager-list') + with mock.patch('substrapp.views.datamanager.query_ledger') as mquery_ledger: + mquery_ledger.return_value = datamanager + + response = self.client.get(url, **self.extra) + r = response.json() + self.assertEqual(r, [datamanager]) + + def test_datamanager_list_filter_fail(self): + url = 
reverse('substrapp:data_manager-list') + with mock.patch('substrapp.views.datamanager.query_ledger') as mquery_ledger: + mquery_ledger.return_value = datamanager + + search_params = '?search=dataseERRORt' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertIn('Malformed search filters', r['message']) + + def test_datamanager_list_filter_name(self): + url = reverse('substrapp:data_manager-list') + with mock.patch('substrapp.views.datamanager.query_ledger') as mquery_ledger: + mquery_ledger.return_value = datamanager + + search_params = '?search=dataset%253Aname%253ASimplified%2520ISIC%25202018' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), 1) + + def test_datamanager_list_filter_objective(self): + url = reverse('substrapp:data_manager-list') + with mock.patch('substrapp.views.datamanager.query_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.filters_utils.query_ledger') as mquery_ledger2: + mquery_ledger.return_value = datamanager + mquery_ledger2.return_value = objective + + search_params = '?search=objective%253Aname%253ASkin%2520Lesion%2520Classification%2520Objective' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), 2) + + def test_datamanager_list_filter_model(self): + url = reverse('substrapp:data_manager-list') + with mock.patch('substrapp.views.datamanager.query_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.filters_utils.query_ledger') as mquery_ledger2: + mquery_ledger.return_value = datamanager + mquery_ledger2.return_value = traintuple + pkhash = model[1]['traintuple']['outModel']['hash'] + search_params = f'?search=model%253Ahash%253A{pkhash}' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), 2) + + def test_datamanager_retrieve(self): + url = reverse('substrapp:data_manager-list') + datamanager_response = [d for d in datamanager + if d['key'] == '8dd01465003a9b1e01c99c904d86aa518b3a5dd9dc8d40fe7d075c726ac073ca'][0] + with mock.patch('substrapp.views.datamanager.get_object_from_ledger') as mget_object_from_ledger, \ + mock.patch('substrapp.views.datamanager.get_remote_asset') as mget_remote_asset: + mget_object_from_ledger.return_value = datamanager_response + + with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), + '../../../../fixtures/chunantes/datamanagers/datamanager0/opener.py'), 'rb') as f: + opener_content = f.read() + + with open(os.path.join( + os.path.dirname(os.path.realpath(__file__)), + '../../../../fixtures/chunantes/datamanagers/datamanager0/description.md'), 'rb') as f: + description_content = f.read() + + mget_remote_asset.side_effect = [opener_content, description_content] + + search_params = '8dd01465003a9b1e01c99c904d86aa518b3a5dd9dc8d40fe7d075c726ac073ca/' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(r, datamanager_response) + + def test_datamanager_retrieve_fail(self): + + dir_path = os.path.dirname(os.path.realpath(__file__)) + url = reverse('substrapp:data_manager-list') + + # PK hash < 64 chars + search_params = '42303efa663015e729159833a12ffb510ff/' + response = self.client.get(url + search_params, **self.extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + # PK hash not hexa + search_params = 'X' * 64 + '/' + response = self.client.get(url + search_params, **self.extra) + 
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + with mock.patch('substrapp.views.datamanager.get_object_from_ledger') as mget_object_from_ledger: + mget_object_from_ledger.side_effect = LedgerError('TEST') + + file_hash = get_hash(os.path.join(dir_path, + "../../../../fixtures/owkin/objectives/objective0/description.md")) + search_params = f'{file_hash}/' + response = self.client.get(url + search_params, **self.extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_datamanager_list_storage_addresses_update(self): + url = reverse('substrapp:data_manager-list') + with mock.patch('substrapp.views.datamanager.query_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.datamanager.get_remote_asset') as mget_remote_asset: + + # mock content + mget_remote_asset.return_value = b'dummy binary content' + ledger_datamanagers = copy.deepcopy(datamanager) + for ledger_datamanager in ledger_datamanagers: + for field in ('description', 'opener'): + ledger_datamanager[field]['storageAddress'] = \ + ledger_datamanager[field]['storageAddress'] \ + .replace('http://testserver', 'http://remotetestserver') + mquery_ledger.return_value = ledger_datamanagers + + # actual test + res = self.client.get(url, **self.extra) + res_datamanagers = res.data[0] + self.assertEqual(len(res_datamanagers), len(datamanager)) + for i, res_datamanager in enumerate(res_datamanagers): + for field in ('description', 'opener'): + self.assertEqual(res_datamanager[field]['storageAddress'], + datamanager[i][field]['storageAddress']) + + def test_datamanager_retrieve_storage_addresses_update_with_cache(self): + url = reverse('substrapp:data_manager-detail', args=[datamanager[0]['key']]) + with mock.patch('substrapp.views.datamanager.get_object_from_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.datamanager.node_has_process_permission', + return_value=True), \ + mock.patch('substrapp.views.datamanager.get_remote_asset') as mget_remote_asset: + + # mock content + mget_remote_asset.return_value = b'dummy binary content' + ledger_datamanager = copy.deepcopy(datamanager[0]) + for field in ('description', 'opener'): + ledger_datamanager[field]['storageAddress'] = \ + ledger_datamanager[field]['storageAddress'].replace('http://testserver', + 'http://remotetestserver') + mquery_ledger.return_value = ledger_datamanager + + # actual test + res = self.client.get(url, **self.extra) + for field in ('description', 'opener'): + self.assertEqual(res.data[field]['storageAddress'], + datamanager[0][field]['storageAddress']) + + def test_datamanager_retrieve_storage_addresses_update_without_cache(self): + url = reverse('substrapp:data_manager-detail', args=[datamanager[0]['key']]) + with mock.patch('substrapp.views.datamanager.get_object_from_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.datamanager.node_has_process_permission', + return_value=False), \ + mock.patch('substrapp.views.datamanager.get_remote_asset') as mget_remote_asset: + + # mock content + mget_remote_asset.return_value = b'dummy binary content' + ledger_datamanager = copy.deepcopy(datamanager[0]) + for field in ('description', 'opener'): + ledger_datamanager[field]['storageAddress'] = \ + ledger_datamanager[field]['storageAddress'].replace('http://testserver', + 'http://remotetestserver') + mquery_ledger.return_value = ledger_datamanager + + # actual test + res = self.client.get(url, **self.extra) + for field in ('description', 'opener'): + self.assertEqual(res.data[field]['storageAddress'], + 
datamanager[0][field]['storageAddress']) diff --git a/backend/substrapp/tests/views/tests_views_datasample.py b/backend/substrapp/tests/views/tests_views_datasample.py new file mode 100644 index 000000000..487412448 --- /dev/null +++ b/backend/substrapp/tests/views/tests_views_datasample.py @@ -0,0 +1,199 @@ +import os +import shutil +import logging + +import mock + +from django.urls import reverse +from django.test import override_settings + +from rest_framework import status +from rest_framework.test import APITestCase + + +from substrapp.serializers import LedgerDataSampleSerializer + +from substrapp.views.datasample import path_leaf +from substrapp.utils import get_hash, uncompress_content + +from substrapp.models import DataManager + +from ..common import get_sample_datamanager, FakeFilterDataManager, AuthenticatedClient + +MEDIA_ROOT = "/tmp/unittests_views/" + + +# APITestCase +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) +@override_settings(DEFAULT_DOMAIN='https://localhost') +@override_settings(LEDGER_SYNC_ENABLED=True) +class DataSampleViewTests(APITestCase): + client_class = AuthenticatedClient + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.data_description, self.data_description_filename, \ + self.data_data_opener, self.data_opener_filename = get_sample_datamanager() + + self.extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0' + } + + self.logger = logging.getLogger('django.request') + self.previous_level = self.logger.getEffectiveLevel() + self.logger.setLevel(logging.ERROR) + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + self.logger.setLevel(self.previous_level) + + def test_data_create_bulk(self): + url = reverse('substrapp:data_sample-list') + + dir_path = os.path.dirname(os.path.realpath(__file__)) + + data_path1 = os.path.join(dir_path, '../../../../fixtures/chunantes/datasamples/datasample1/0024700.zip') + data_path2 = os.path.join(dir_path, '../../../../fixtures/chunantes/datasamples/datasample0/0024899.zip') + + # dir hash + pkhash1 = '24fb12ff87485f6b0bc5349e5bf7f36ccca4eb1353395417fdae7d8d787f178c' + pkhash2 = '30f6c797e277451b0a08da7119ed86fb2986fa7fab2258bf3edbd9f1752ed553' + + data_manager_keys = [ + get_hash(os.path.join(dir_path, '../../../../fixtures/chunantes/datamanagers/datamanager0/opener.py'))] + + data = { + 'files': [path_leaf(data_path1), path_leaf(data_path2)], + path_leaf(data_path1): open(data_path1, 'rb'), + path_leaf(data_path2): open(data_path2, 'rb'), + 'data_manager_keys': data_manager_keys, + 'test_only': False + } + + with mock.patch.object(DataManager.objects, 'filter') as mdatamanager, \ + mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: + + mdatamanager.return_value = FakeFilterDataManager(1) + mcreate.return_value = {'keys': [pkhash1, pkhash2]} + response = self.client.post(url, data=data, format='multipart', **self.extra) + self.assertEqual([r['pkhash'] for r in response.data], [pkhash1, pkhash2]) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + for x in data['files']: + data[x].close() + + def test_data_create(self): + url = reverse('substrapp:data_sample-list') + + dir_path = os.path.dirname(os.path.realpath(__file__)) + + data_path = os.path.join(dir_path, '../../../../fixtures/chunantes/datasamples/datasample1/0024700.zip') + + # dir hash + pkhash = '24fb12ff87485f6b0bc5349e5bf7f36ccca4eb1353395417fdae7d8d787f178c' + + data_manager_keys = [ + 
get_hash(os.path.join(dir_path, '../../../../fixtures/chunantes/datamanagers/datamanager0/opener.py'))] + + data = { + 'file': open(data_path, 'rb'), + 'data_manager_keys': data_manager_keys, + 'test_only': False + } + + with mock.patch.object(DataManager.objects, 'filter') as mdatamanager, \ + mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: + + mdatamanager.return_value = FakeFilterDataManager(1) + mcreate.return_value = {'keys': [pkhash]} + response = self.client.post(url, data=data, format='multipart', **self.extra) + + self.assertEqual(response.data[0]['pkhash'], pkhash) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + data['file'].close() + + def test_data_create_parent_path(self): + url = reverse('substrapp:data_sample-list') + + dir_path = os.path.dirname(os.path.realpath(__file__)) + + data_zip_path = os.path.join(dir_path, '../../../../fixtures/chunantes/datasamples/datasample1/0024700.zip') + data_parent_path = os.path.join(MEDIA_ROOT, 'data_samples') + data_path = os.path.join(data_parent_path, '0024700') + + with open(data_zip_path, 'rb') as data_zip: + uncompress_content(data_zip.read(), data_path) + + # dir hash + pkhash = '24fb12ff87485f6b0bc5349e5bf7f36ccca4eb1353395417fdae7d8d787f178c' + + data_manager_keys = [ + get_hash(os.path.join(dir_path, '../../../../fixtures/chunantes/datamanagers/datamanager0/opener.py'))] + + data = { + 'path': data_parent_path, + 'data_manager_keys': data_manager_keys, + 'test_only': False, + 'multiple': True, + } + + with mock.patch.object(DataManager.objects, 'filter') as mdatamanager, \ + mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: + + mdatamanager.return_value = FakeFilterDataManager(1) + mcreate.return_value = {'keys': [pkhash]} + response = self.client.post(url, data=data, format='multipart', **self.extra) + + self.assertEqual(response.data[0]['pkhash'], pkhash) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + def test_data_create_path(self): + url = reverse('substrapp:data_sample-list') + + dir_path = os.path.dirname(os.path.realpath(__file__)) + + data_zip_path = os.path.join(dir_path, '../../../../fixtures/chunantes/datasamples/datasample1/0024700.zip') + data_path = os.path.join(MEDIA_ROOT, '0024700') + + with open(data_zip_path, 'rb') as data_zip: + uncompress_content(data_zip.read(), data_path) + + # dir hash + pkhash = '24fb12ff87485f6b0bc5349e5bf7f36ccca4eb1353395417fdae7d8d787f178c' + + data_manager_keys = [ + get_hash(os.path.join(dir_path, '../../../../fixtures/chunantes/datamanagers/datamanager0/opener.py'))] + + data = { + 'path': data_path, + 'data_manager_keys': data_manager_keys, + 'test_only': False + } + + with mock.patch.object(DataManager.objects, 'filter') as mdatamanager, \ + mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: + + mdatamanager.return_value = FakeFilterDataManager(1) + mcreate.return_value = {'keys': [pkhash]} + response = self.client.post(url, data=data, format='multipart', **self.extra) + + self.assertEqual(response.data[0]['pkhash'], pkhash) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + def test_datasamples_list(self): + url = reverse('substrapp:data_sample-list') + with mock.patch('substrapp.views.datasample.query_ledger') as mquery_ledger: + mquery_ledger.side_effect = [None, ['DataSampleA', 'DataSampleB']] + + response = self.client.get(url, **self.extra) + r = response.json() + self.assertEqual(r, []) + + response = self.client.get(url, **self.extra) + r = response.json() 
+ self.assertEqual(r, ['DataSampleA', 'DataSampleB']) diff --git a/backend/substrapp/tests/views/tests_views_model.py b/backend/substrapp/tests/views/tests_views_model.py new file mode 100644 index 000000000..c6f5f8e57 --- /dev/null +++ b/backend/substrapp/tests/views/tests_views_model.py @@ -0,0 +1,161 @@ +import os +import shutil +import logging + +import mock + +from django.urls import reverse +from django.test import override_settings + +from rest_framework import status +from rest_framework.test import APITestCase + +from substrapp.ledger_utils import LedgerError + +from substrapp.utils import get_hash + +from ..common import get_sample_model, AuthenticatedClient +from ..assets import objective, datamanager, algo, model + +MEDIA_ROOT = "/tmp/unittests_views/" + + +# APITestCase +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) +class ModelViewTests(APITestCase): + client_class = AuthenticatedClient + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.model, self.model_filename = get_sample_model() + + self.extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0' + } + + self.logger = logging.getLogger('django.request') + self.previous_level = self.logger.getEffectiveLevel() + self.logger.setLevel(logging.ERROR) + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + self.logger.setLevel(self.previous_level) + + def test_model_list_empty(self): + url = reverse('substrapp:model-list') + with mock.patch('substrapp.views.model.query_ledger') as mquery_ledger: + mquery_ledger.side_effect = [[], ['ISIC']] + + response = self.client.get(url, **self.extra) + r = response.json() + self.assertEqual(r, [[]]) + + response = self.client.get(url, **self.extra) + r = response.json() + self.assertEqual(r, [['ISIC']]) + + def test_model_list_filter_fail(self): + + with mock.patch('substrapp.views.model.query_ledger') as mquery_ledger: + mquery_ledger.return_value = model + + url = reverse('substrapp:model-list') + search_params = '?search=modeERRORl' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + self.assertIn('Malformed search filters', r['message']) + + def test_model_list_filter_hash(self): + + with mock.patch('substrapp.views.model.query_ledger') as mquery_ledger: + mquery_ledger.return_value = model + + pkhash = model[1]['traintuple']['outModel']['hash'] + url = reverse('substrapp:model-list') + search_params = f'?search=model%253Ahash%253A{pkhash}' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + self.assertEqual(len(r[0]), 1) + + def test_model_list_filter_datamanager(self): + url = reverse('substrapp:model-list') + with mock.patch('substrapp.views.model.query_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.filters_utils.query_ledger') as mquery_ledger2: + mquery_ledger.return_value = model + mquery_ledger2.return_value = datamanager + + search_params = '?search=dataset%253Aname%253AISIC%25202018' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), 4) + + def test_model_list_filter_objective(self): + url = reverse('substrapp:model-list') + with mock.patch('substrapp.views.model.query_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.filters_utils.query_ledger') as mquery_ledger2: + mquery_ledger.return_value = model + mquery_ledger2.return_value = objective + + search_params = 
'?search=objective%253Aname%253ASkin%2520Lesion%2520Classification%2520Objective' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), 4) + + def test_model_list_filter_algo(self): + url = reverse('substrapp:model-list') + with mock.patch('substrapp.views.model.query_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.filters_utils.query_ledger') as mquery_ledger2: + mquery_ledger.return_value = model + mquery_ledger2.return_value = algo + + search_params = '?search=algo%253Aname%253ALogistic%2520regression' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), 2) + + def test_model_retrieve(self): + + with mock.patch('substrapp.views.model.get_object_from_ledger') as mget_object_from_ledger, \ + mock.patch('substrapp.views.model.get_remote_asset') as get_remote_asset: + mget_object_from_ledger.return_value = model[1] + + get_remote_asset.return_value = self.model.read().encode() + + url = reverse('substrapp:model-list') + search_params = model[1]['traintuple']['outModel']['hash'] + '/' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + self.assertEqual(r, model[1]) + + def test_model_retrieve_fail(self): + + dir_path = os.path.dirname(os.path.realpath(__file__)) + + url = reverse('substrapp:model-list') + + # PK hash < 64 chars + search_params = '42303efa663015e729159833a12ffb510ff/' + response = self.client.get(url + search_params, **self.extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + # PK hash not hexa + search_params = 'X' * 64 + '/' + response = self.client.get(url + search_params, **self.extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + with mock.patch('substrapp.views.model.get_object_from_ledger') as mget_object_from_ledger: + mget_object_from_ledger.side_effect = LedgerError('TEST') + + file_hash = get_hash(os.path.join(dir_path, + "../../../../fixtures/owkin/objectives/objective0/description.md")) + search_params = f'{file_hash}/' + response = self.client.get(url + search_params, **self.extra) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) diff --git a/backend/substrapp/tests/views/tests_views_objective.py b/backend/substrapp/tests/views/tests_views_objective.py new file mode 100644 index 000000000..000451d9a --- /dev/null +++ b/backend/substrapp/tests/views/tests_views_objective.py @@ -0,0 +1,325 @@ +import os +import shutil +import logging +import zipfile +import copy + +import mock + +from django.urls import reverse +from django.test import override_settings + +from rest_framework import status +from rest_framework.test import APITestCase + +from substrapp.serializers import LedgerObjectiveSerializer + +from substrapp.ledger_utils import LedgerError + +from substrapp.utils import compute_hash, get_hash + +from ..common import get_sample_objective, AuthenticatedClient +from ..assets import objective, datamanager, traintuple, model + +MEDIA_ROOT = "/tmp/unittests_views/" + + +def zip_folder(path, destination): + zipf = zipfile.ZipFile(destination, 'w', zipfile.ZIP_DEFLATED) + for root, dirs, files in os.walk(path): + for f in files: + abspath = os.path.join(root, f) + archive_path = os.path.relpath(abspath, start=path) + zipf.write(abspath, arcname=archive_path) + zipf.close() + + +# APITestCase +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) 
+@override_settings(DEFAULT_DOMAIN='https://localhost') +@override_settings(LEDGER_SYNC_ENABLED=True) +class ObjectiveViewTests(APITestCase): + client_class = AuthenticatedClient + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.objective_description, self.objective_description_filename, \ + self.objective_metrics, self.objective_metrics_filename = get_sample_objective() + + self.test_data_sample_keys = [ + "2d0f943aa81a9cb3fe84b162559ce6aff068ccb04e0cb284733b8f9d7e06517e", + "533ee6e7b9d8b247e7e853b24547f57e6ef351852bac0418f13a0666173448f1" + ] + + self.extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0' + } + + self.logger = logging.getLogger('django.request') + self.previous_level = self.logger.getEffectiveLevel() + self.logger.setLevel(logging.ERROR) + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + self.logger.setLevel(self.previous_level) + + def test_objective_list_empty(self): + url = reverse('substrapp:objective-list') + with mock.patch('substrapp.views.objective.query_ledger') as mquery_ledger: + mquery_ledger.return_value = [] + + response = self.client.get(url, **self.extra) + r = response.json() + self.assertEqual(r, [[]]) + + def test_objective_list_success(self): + url = reverse('substrapp:objective-list') + with mock.patch('substrapp.views.objective.query_ledger') as mquery_ledger: + mquery_ledger.return_value = objective + + response = self.client.get(url, **self.extra) + r = response.json() + self.assertEqual(r, [objective]) + + def test_objective_list_filter_fail(self): + url = reverse('substrapp:objective-list') + with mock.patch('substrapp.views.objective.query_ledger') as mquery_ledger: + mquery_ledger.return_value = objective + + search_params = '?search=challenERRORge' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertIn('Malformed search filters', r['message']) + + def test_objective_list_filter_name(self): + url = reverse('substrapp:objective-list') + with mock.patch('substrapp.views.objective.query_ledger') as mquery_ledger: + mquery_ledger.return_value = objective + + search_params = '?search=objective%253Aname%253ASkin%2520Lesion%2520Classification%2520Objective' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), 2) + + def test_objective_list_filter_metrics(self): + url = reverse('substrapp:objective-list') + with mock.patch('substrapp.views.objective.query_ledger') as mquery_ledger: + mquery_ledger.return_value = objective + + search_params = '?search=objective%253Ametrics%253Amacro-average%2520recall' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), len(objective)) + + def test_objective_list_filter_datamanager(self): + url = reverse('substrapp:objective-list') + with mock.patch('substrapp.views.objective.query_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.filters_utils.query_ledger') as mquery_ledger2: + mquery_ledger.return_value = objective + mquery_ledger2.return_value = datamanager + + search_params = '?search=dataset%253Aname%253ASimplified%2520ISIC%25202018' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), 1) + + def test_objective_list_filter_model(self): + url = reverse('substrapp:objective-list') + with mock.patch('substrapp.views.objective.query_ledger') as mquery_ledger, \ + 
mock.patch('substrapp.views.filters_utils.query_ledger') as mquery_ledger2: + mquery_ledger.return_value = objective + mquery_ledger2.return_value = traintuple + + pkhash = model[1]['traintuple']['outModel']['hash'] + search_params = f'?search=model%253Ahash%253A{pkhash}' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), 1) + + def test_objective_retrieve(self): + url = reverse('substrapp:objective-list') + + with mock.patch('substrapp.views.objective.get_object_from_ledger') as mget_object_from_ledger, \ + mock.patch('substrapp.views.objective.get_remote_asset') as get_remote_asset: + mget_object_from_ledger.return_value = objective[0] + + with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), + '../../../../fixtures/owkin/objectives/objective0/description.md'), 'rb') as f: + content = f.read() + + get_remote_asset.return_value = content + + pkhash = compute_hash(content) + search_params = f'{pkhash}/' + + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(r, objective[0]) + + def test_objective_retrieve_fail(self): + + dir_path = os.path.dirname(os.path.realpath(__file__)) + url = reverse('substrapp:objective-list') + + # PK hash < 64 chars + search_params = '42303efa663015e729159833a12ffb510ff/' + response = self.client.get(url + search_params, **self.extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + # PK hash not hexa + search_params = 'X' * 64 + '/' + response = self.client.get(url + search_params, **self.extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + with mock.patch('substrapp.views.objective.get_object_from_ledger') as mget_object_from_ledger: + mget_object_from_ledger.side_effect = LedgerError('TEST') + + file_hash = get_hash(os.path.join(dir_path, + "../../../../fixtures/owkin/objectives/objective0/description.md")) + search_params = f'{file_hash}/' + response = self.client.get(url + search_params, **self.extra) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_objective_create(self): + url = reverse('substrapp:objective-list') + + dir_path = os.path.dirname(os.path.realpath(__file__)) + + objective_path = os.path.join(dir_path, '../../../../fixtures/owkin/objectives/objective0/') + + description_path = os.path.join(objective_path, 'description.md') + + metrics_path = os.path.join(MEDIA_ROOT, 'metrics.zip') + + zip_folder(objective_path, metrics_path) + + pkhash = get_hash(description_path) + + test_data_manager_key = get_hash(os.path.join( + dir_path, '../../../../fixtures/owkin/datamanagers/datamanager0/opener.py')) + + data = { + 'name': 'Simplified skin lesion classification', + 'description': open(description_path, 'rb'), + 'metrics_name': 'macro-average recall', + 'metrics': open(metrics_path, 'rb'), + 'permissions_public': True, + 'permissions_authorized_ids': [], + 'test_data_sample_keys': self.test_data_sample_keys, + 'test_data_manager_key': test_data_manager_key + } + + with mock.patch.object(LedgerObjectiveSerializer, 'create') as mcreate: + + mcreate.return_value = {} + + response = self.client.post(url, data=data, format='multipart', **self.extra) + + self.assertEqual(response.data['pkhash'], pkhash) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + data['description'].close() + data['metrics'].close() + + def test_objective_leaderboard_sort(self): + url = reverse('substrapp:objective-leaderboard', 
args=[objective[0]['key']]) + with mock.patch('substrapp.views.objective.query_ledger') as mquery_ledger: + mquery_ledger.return_value = {} + + self.client.get(url, data={'sort': 'desc'}, **self.extra) + mquery_ledger.assert_called_with( + fcn='queryObjectiveLeaderboard', + args={ + 'objectiveKey': objective[0]['key'], + 'ascendingOrder': False, + }) + + self.client.get(url, data={'sort': 'asc'}, **self.extra) + mquery_ledger.assert_called_with( + fcn='queryObjectiveLeaderboard', + args={ + 'objectiveKey': objective[0]['key'], + 'ascendingOrder': True, + }) + + response = self.client.get(url, data={'sort': 'foo'}, **self.extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_objective_list_storage_addresses_update(self): + url = reverse('substrapp:objective-list') + with mock.patch('substrapp.views.objective.query_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.objective.get_remote_asset') as mget_remote_asset: + + # mock content + mget_remote_asset.return_value = b'dummy binary content' + ledger_objectives = copy.deepcopy(objective) + for ledger_objective in ledger_objectives: + for field in ('description', 'metrics'): + ledger_objective[field]['storageAddress'] = \ + ledger_objective[field]['storageAddress'] \ + .replace('http://testserver', 'http://remotetestserver') + mquery_ledger.return_value = ledger_objectives + + # actual test + res = self.client.get(url, **self.extra) + res_objectives = res.data[0] + self.assertEqual(len(res_objectives), len(objective)) + for i, res_objective in enumerate(res_objectives): + for field in ('description', 'metrics'): + self.assertEqual(res_objective[field]['storageAddress'], + objective[i][field]['storageAddress']) + + def test_objective_retrieve_storage_addresses_update_with_cache(self): + url = reverse('substrapp:objective-detail', args=[objective[0]['key']]) + with mock.patch('substrapp.views.objective.get_object_from_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.objective.node_has_process_permission', + return_value=True), \ + mock.patch('substrapp.views.objective.get_remote_asset') as mget_remote_asset: + + # mock content + mget_remote_asset.return_value = b'dummy binary content' + ledger_objective = copy.deepcopy(objective[0]) + for field in ('description', 'metrics'): + ledger_objective[field]['storageAddress'] = \ + ledger_objective[field]['storageAddress'].replace('http://testserver', + 'http://remotetestserver') + mquery_ledger.return_value = ledger_objective + + # actual test + res = self.client.get(url, **self.extra) + for field in ('description', 'metrics'): + self.assertEqual(res.data[field]['storageAddress'], + objective[0][field]['storageAddress']) + + def test_objective_retrieve_storage_addresses_update_without_cache(self): + url = reverse('substrapp:objective-detail', args=[objective[0]['key']]) + with mock.patch('substrapp.views.objective.get_object_from_ledger') as mquery_ledger, \ + mock.patch('substrapp.views.objective.node_has_process_permission', + return_value=False), \ + mock.patch('substrapp.views.objective.get_remote_asset') as mget_remote_asset: + + # mock content + mget_remote_asset.return_value = b'dummy binary content' + ledger_objective = copy.deepcopy(objective[0]) + for field in ('description', 'metrics'): + ledger_objective[field]['storageAddress'] = \ + ledger_objective[field]['storageAddress'].replace('http://testserver', + 'http://remotetestserver') + mquery_ledger.return_value = ledger_objective + + # actual test + res = self.client.get(url, 
**self.extra) + for field in ('description', 'metrics'): + self.assertEqual(res.data[field]['storageAddress'], + objective[0][field]['storageAddress']) diff --git a/backend/substrapp/tests/views/tests_views_task.py b/backend/substrapp/tests/views/tests_views_task.py new file mode 100644 index 000000000..ecb447ba0 --- /dev/null +++ b/backend/substrapp/tests/views/tests_views_task.py @@ -0,0 +1,64 @@ +import os +import shutil +import logging + +import mock + +from django.urls import reverse +from django.test import override_settings + +from rest_framework import status +from rest_framework.test import APITestCase + +from ..common import FakeAsyncResult, AuthenticatedClient + +MEDIA_ROOT = "/tmp/unittests_views/" + + +# APITestCase +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) +class TaskViewTests(APITestCase): + client_class = AuthenticatedClient + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0' + } + + self.logger = logging.getLogger('django.request') + self.previous_level = self.logger.getEffectiveLevel() + self.logger.setLevel(logging.ERROR) + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + self.logger.setLevel(self.previous_level) + + def test_task_retrieve(self): + + url = reverse('substrapp:task-detail', kwargs={'pk': 'pk'}) + with mock.patch('substrapp.views.task.AsyncResult') as mAsyncResult: + mAsyncResult.return_value = FakeAsyncResult(status='SUCCESS') + response = self.client.get(url, **self.extra) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_task_retrieve_fail(self): + url = reverse('substrapp:task-detail', kwargs={'pk': 'pk'}) + with mock.patch('substrapp.views.task.AsyncResult') as mAsyncResult: + mAsyncResult.return_value = FakeAsyncResult() + response = self.client.get(url, **self.extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_task_retrieve_pending(self): + url = reverse('substrapp:task-detail', kwargs={'pk': 'pk'}) + with mock.patch('substrapp.views.task.AsyncResult') as mAsyncResult: + mAsyncResult.return_value = FakeAsyncResult(status='PENDING', successful=False) + response = self.client.get(url, **self.extra) + self.assertEqual(response.data['message'], + 'Task is either waiting, does not exist in this context or has been removed after 24h') + + self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/backend/substrapp/tests/views/tests_views_tuples.py b/backend/substrapp/tests/views/tests_views_tuples.py new file mode 100644 index 000000000..e8e3e646c --- /dev/null +++ b/backend/substrapp/tests/views/tests_views_tuples.py @@ -0,0 +1,213 @@ +import os +import shutil +import logging + +import mock + +from django.urls import reverse +from django.test import override_settings + +from rest_framework import status +from rest_framework.test import APITestCase + +from substrapp.views import TrainTupleViewSet, TestTupleViewSet + +from substrapp.utils import get_hash + +from substrapp.ledger_utils import LedgerError + +from ..assets import traintuple, testtuple +from ..common import AuthenticatedClient + +MEDIA_ROOT = "/tmp/unittests_views/" + + +def get_compute_plan_id(assets): + for asset in assets: + compute_plan_id = asset.get('computePlanID') + if compute_plan_id: + return compute_plan_id + raise Exception('Could not find a compute plan ID') + + +# APITestCase 
+@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) +class TraintupleViewTests(APITestCase): + client_class = AuthenticatedClient + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0' + } + + self.logger = logging.getLogger('django.request') + self.previous_level = self.logger.getEffectiveLevel() + self.logger.setLevel(logging.ERROR) + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + self.logger.setLevel(self.previous_level) + + def test_traintuple_queryset(self): + traintuple_view = TrainTupleViewSet() + self.assertFalse(traintuple_view.get_queryset()) + + def test_traintuple_list_empty(self): + url = reverse('substrapp:traintuple-list') + with mock.patch('substrapp.views.traintuple.query_ledger') as mquery_ledger: + mquery_ledger.return_value = [] + + response = self.client.get(url, **self.extra) + r = response.json() + self.assertEqual(r, [[]]) + + def test_traintuple_retrieve(self): + + with mock.patch('substrapp.views.traintuple.get_object_from_ledger') as mget_object_from_ledger: + mget_object_from_ledger.return_value = traintuple[0] + url = reverse('substrapp:traintuple-list') + search_params = 'c164f4c714a78c7e2ba2016de231cdd41e3eac61289e08c1f711e74915a0868f/' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + self.assertEqual(r, traintuple[0]) + + def test_traintuple_retrieve_fail(self): + + dir_path = os.path.dirname(os.path.realpath(__file__)) + url = reverse('substrapp:traintuple-list') + + # PK hash < 64 chars + search_params = '42303efa663015e729159833a12ffb510ff/' + response = self.client.get(url + search_params, **self.extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + # PK hash not hexa + search_params = 'X' * 64 + '/' + response = self.client.get(url + search_params, **self.extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + with mock.patch('substrapp.views.traintuple.get_object_from_ledger') as mget_object_from_ledger: + mget_object_from_ledger.side_effect = LedgerError('Test') + + file_hash = get_hash(os.path.join(dir_path, + "../../../../fixtures/owkin/objectives/objective0/description.md")) + search_params = f'{file_hash}/' + response = self.client.get(url + search_params, **self.extra) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_traintuple_list_filter_tag(self): + url = reverse('substrapp:traintuple-list') + with mock.patch('substrapp.views.traintuple.query_ledger') as mquery_ledger: + mquery_ledger.return_value = traintuple + + search_params = '?search=traintuple%253Atag%253Asubstra' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), 1) + + def test_traintuple_list_filter_compute_plan_id(self): + url = reverse('substrapp:traintuple-list') + with mock.patch('substrapp.views.traintuple.query_ledger') as mquery_ledger: + mquery_ledger.return_value = traintuple + compute_plan_id = get_compute_plan_id(traintuple) + search_params = f'?search=traintuple%253AcomputePlanID%253A{compute_plan_id}' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), 1) + + +# APITestCase +@override_settings(MEDIA_ROOT=MEDIA_ROOT) +@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) +class TesttupleViewTests(APITestCase): + 
client_class = AuthenticatedClient + + def setUp(self): + if not os.path.exists(MEDIA_ROOT): + os.makedirs(MEDIA_ROOT) + + self.extra = { + 'HTTP_ACCEPT': 'application/json;version=0.0' + } + + self.logger = logging.getLogger('django.request') + self.previous_level = self.logger.getEffectiveLevel() + self.logger.setLevel(logging.ERROR) + + def tearDown(self): + shutil.rmtree(MEDIA_ROOT, ignore_errors=True) + + self.logger.setLevel(self.previous_level) + + def test_testtuple_queryset(self): + testtuple_view = TestTupleViewSet() + self.assertFalse(testtuple_view.get_queryset()) + + def test_testtuple_list_empty(self): + url = reverse('substrapp:testtuple-list') + with mock.patch('substrapp.views.testtuple.query_ledger') as mquery_ledger: + mquery_ledger.return_value = [] + + response = self.client.get(url, **self.extra) + r = response.json() + self.assertEqual(r, [[]]) + + def test_testtuple_retrieve(self): + + with mock.patch('substrapp.views.testtuple.get_object_from_ledger') as mget_object_from_ledger: + mget_object_from_ledger.return_value = testtuple[0] + url = reverse('substrapp:testtuple-list') + search_params = 'c164f4c714a78c7e2ba2016de231cdd41e3eac61289e08c1f711e74915a0868f/' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + self.assertEqual(r, testtuple[0]) + + def test_testtuple_retrieve_fail(self): + + dir_path = os.path.dirname(os.path.realpath(__file__)) + url = reverse('substrapp:testtuple-list') + + # PK hash < 64 chars + search_params = '42303efa663015e729159833a12ffb510ff/' + response = self.client.get(url + search_params, **self.extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + # PK hash not hexa + search_params = 'X' * 64 + '/' + response = self.client.get(url + search_params, **self.extra) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + with mock.patch('substrapp.views.testtuple.get_object_from_ledger') as mget_object_from_ledger: + mget_object_from_ledger.side_effect = LedgerError('Test') + + file_hash = get_hash(os.path.join(dir_path, + "../../../../fixtures/owkin/objectives/objective0/description.md")) + search_params = f'{file_hash}/' + response = self.client.get(url + search_params, **self.extra) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_testtuple_list_filter_tag(self): + url = reverse('substrapp:testtuple-list') + with mock.patch('substrapp.views.testtuple.query_ledger') as mquery_ledger: + mquery_ledger.return_value = testtuple + + search_params = '?search=testtuple%253Atag%253Asubstra' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), 1) + + search_params = '?search=testtuple%253Atag%253Afoo' + response = self.client.get(url + search_params, **self.extra) + r = response.json() + + self.assertEqual(len(r[0]), 0) diff --git a/substrabac/substrapp/urls.py b/backend/substrapp/urls.py similarity index 63% rename from substrabac/substrapp/urls.py rename to backend/substrapp/urls.py index 312ce188c..96436fa2a 100644 --- a/substrabac/substrapp/urls.py +++ b/backend/substrapp/urls.py @@ -6,19 +6,27 @@ from rest_framework.routers import DefaultRouter from substrapp.views import ObjectiveViewSet, DataSampleViewSet, DataManagerViewSet, \ - AlgoViewSet, TrainTupleViewSet, TestTupleViewSet, ModelViewSet, TaskViewSet + AlgoViewSet, TrainTupleViewSet, TestTupleViewSet, ModelViewSet, TaskViewSet, \ + ComputePlanViewSet, ObjectivePermissionViewSet, AlgoPermissionViewSet, 
DataManagerPermissionViewSet, \ + ModelPermissionViewSet + # Create a router and register our viewsets with it. router = DefaultRouter() router.register(r'objective', ObjectiveViewSet, base_name='objective') +router.register(r'objective', ObjectivePermissionViewSet, base_name='objective') router.register(r'model', ModelViewSet, base_name='model') +router.register(r'model', ModelPermissionViewSet, base_name='model') router.register(r'data_sample', DataSampleViewSet, base_name='data_sample') router.register(r'data_manager', DataManagerViewSet, base_name='data_manager') +router.register(r'data_manager', DataManagerPermissionViewSet, base_name='data_manager') router.register(r'algo', AlgoViewSet, base_name='algo') +router.register(r'algo', AlgoPermissionViewSet, base_name='algo') router.register(r'traintuple', TrainTupleViewSet, base_name='traintuple') router.register(r'testtuple', TestTupleViewSet, base_name='testtuple') router.register(r'task', TaskViewSet, base_name='task') +router.register(r'compute_plan', ComputePlanViewSet, base_name='compute_plan') urlpatterns = [ url(r'^', include(router.urls)), diff --git a/backend/substrapp/utils.py b/backend/substrapp/utils.py new file mode 100644 index 000000000..c144dd7ed --- /dev/null +++ b/backend/substrapp/utils.py @@ -0,0 +1,186 @@ +import io +import hashlib +import logging +import os +import tempfile +from os import path +from os.path import isfile, isdir +import shutil + +import requests +import tarfile +import zipfile +import uuid + +from checksumdir import dirhash + +from django.conf import settings +from rest_framework import status + + +class JsonException(Exception): + def __init__(self, msg): + self.msg = msg + super(JsonException, self).__init__() + + +def get_dir_hash(archive_object): + with tempfile.TemporaryDirectory() as temp_dir: + try: + content = archive_object.read() + archive_object.seek(0) + uncompress_content(content, temp_dir) + except Exception as e: + logging.error(e) + raise e + else: + return dirhash(temp_dir, 'sha256') + + +def store_datasamples_archive(archive_object): + + try: + content = archive_object.read() + archive_object.seek(0) + except Exception as e: + logging.error(e) + raise e + + # Temporary directory for uncompress + datasamples_uuid = uuid.uuid4().hex + tmp_datasamples_path = path.join(getattr(settings, 'MEDIA_ROOT'), + f'datasamples/{datasamples_uuid}') + try: + uncompress_content(content, tmp_datasamples_path) + except Exception as e: + shutil.rmtree(tmp_datasamples_path, ignore_errors=True) + logging.error(e) + raise e + else: + # return the directory hash of the uncompressed file and the path of + # the temporary directory. The removal should be handled externally. 
+ return dirhash(tmp_datasamples_path, 'sha256'), tmp_datasamples_path + + +def get_hash(file, key=None): + if file is None: + return '' + else: + if isinstance(file, (str, bytes, os.PathLike)): + if isfile(file): + with open(file, 'rb') as f: + data = f.read() + elif isdir(file): + return dirhash(file, 'sha256') + else: + return '' + else: + openedfile = file.open() + data = openedfile.read() + openedfile.seek(0) + + return compute_hash(data, key) + + +def get_owner(): + ledger_settings = getattr(settings, 'LEDGER') + return ledger_settings['client']['msp_id'] + + +def compute_hash(bytes, key=None): + sha256_hash = hashlib.sha256() + + if isinstance(bytes, str): + bytes = bytes.encode() + + if key is not None and isinstance(key, str): + bytes += key.encode() + + sha256_hash.update(bytes) + + return sha256_hash.hexdigest() + + +def create_directory(directory): + if not os.path.exists(directory): + os.makedirs(directory) + + +class ZipFile(zipfile.ZipFile): + """Override Zipfile to ensure unix file permissions are preserved. + + This is due to a python bug: + https://bugs.python.org/issue15795 + + Workaround from: + https://stackoverflow.com/questions/39296101/python-zipfile-removes-execute-permissions-from-binaries + """ + + def extract(self, member, path=None, pwd=None): + if not isinstance(member, zipfile.ZipInfo): + member = self.getinfo(member) + + if path is None: + path = os.getcwd() + + ret_val = self._extract_member(member, path, pwd) + attr = member.external_attr >> 16 + os.chmod(ret_val, attr) + return ret_val + + +def uncompress_path(archive_path, to_directory): + if zipfile.is_zipfile(archive_path): + with ZipFile(archive_path, 'r') as zf: + zf.extractall(to_directory) + elif tarfile.is_tarfile(archive_path): + with tarfile.open(archive_path, 'r:*') as tf: + tf.extractall(to_directory) + else: + raise Exception('Archive must be zip or tar.gz') + + +def uncompress_content(archive_content, to_directory): + if zipfile.is_zipfile(io.BytesIO(archive_content)): + with ZipFile(io.BytesIO(archive_content)) as zf: + zf.extractall(to_directory) + else: + try: + with tarfile.open(fileobj=io.BytesIO(archive_content)) as tf: + tf.extractall(to_directory) + except tarfile.TarError: + raise Exception('Archive must be zip or tar.*') + + +class NodeError(Exception): + pass + + +def get_remote_file(url, auth, **kwargs): + kwargs.update({ + 'headers': {'Accept': 'application/json;version=0.0'}, + 'auth': auth + }) + + if settings.DEBUG: + kwargs['verify'] = False + + try: + response = requests.get(url, **kwargs) + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e: + raise NodeError(f'Failed to fetch {url}') from e + + return response + + +def get_remote_file_content(url, auth, content_hash, salt=None): + response = get_remote_file(url, auth) + + if response.status_code != status.HTTP_200_OK: + logging.error(response.text) + raise NodeError(f'Url: {url} returned status code: {response.status_code}') + + computed_hash = compute_hash(response.content, key=salt) + if computed_hash != content_hash: + raise NodeError(f"url {url}: hash doesn't match {content_hash} vs {computed_hash}") + return response.content diff --git a/backend/substrapp/views/__init__.py b/backend/substrapp/views/__init__.py new file mode 100644 index 000000000..eb377cd2d --- /dev/null +++ b/backend/substrapp/views/__init__.py @@ -0,0 +1,16 @@ +# encoding: utf-8 + +from .datasample import DataSampleViewSet +from .datamanager import DataManagerViewSet, DataManagerPermissionViewSet +from .objective import 
ObjectiveViewSet, ObjectivePermissionViewSet +from .model import ModelViewSet, ModelPermissionViewSet +from .algo import AlgoViewSet, AlgoPermissionViewSet +from .traintuple import TrainTupleViewSet +from .testtuple import TestTupleViewSet +from .task import TaskViewSet +from .computeplan import ComputePlanViewSet + +__all__ = ['DataSampleViewSet', 'DataManagerViewSet', 'DataManagerPermissionViewSet', 'ObjectiveViewSet', + 'ObjectivePermissionViewSet', 'ModelViewSet', 'ModelPermissionViewSet', 'AlgoViewSet', + 'AlgoPermissionViewSet', 'TrainTupleViewSet', 'TestTupleViewSet', 'TaskViewSet', 'ComputePlanViewSet' + ] diff --git a/backend/substrapp/views/algo.py b/backend/substrapp/views/algo.py new file mode 100644 index 000000000..79c3180e9 --- /dev/null +++ b/backend/substrapp/views/algo.py @@ -0,0 +1,219 @@ +import tempfile +import logging + +from django.http import Http404 +from django.urls import reverse +from rest_framework import status, mixins +from rest_framework.decorators import action +from rest_framework.exceptions import ValidationError +from rest_framework.response import Response +from rest_framework.viewsets import GenericViewSet + +from substrapp.models import Algo +from substrapp.serializers import LedgerAlgoSerializer, AlgoSerializer +from substrapp.utils import get_hash +from substrapp.ledger_utils import query_ledger, get_object_from_ledger, LedgerError, LedgerTimeout, LedgerConflict +from substrapp.views.utils import (PermissionMixin, find_primary_key_error, + validate_pk, get_success_create_code, LedgerException, ValidationException, + get_remote_asset, node_has_process_permission) +from substrapp.views.filters_utils import filter_list + + +def replace_storage_addresses(request, algo): + algo['description']['storageAddress'] = request.build_absolute_uri( + reverse('substrapp:algo-description', args=[algo['key']])) + algo['content']['storageAddress'] = request.build_absolute_uri( + reverse('substrapp:algo-file', args=[algo['key']]) + ) + + +class AlgoViewSet(mixins.CreateModelMixin, + mixins.RetrieveModelMixin, + mixins.ListModelMixin, + GenericViewSet): + queryset = Algo.objects.all() + serializer_class = AlgoSerializer + ledger_query_call = 'queryAlgo' + + def perform_create(self, serializer): + return serializer.save() + + def commit(self, serializer, request): + # create on db + instance = self.perform_create(serializer) + + ledger_data = { + 'name': request.data.get('name'), + # XXX workaround because input is a QueryDict and not a JSON object. 
This + # is due to the fact that we are sending file object and body in a + # single HTTP request + 'permissions': { + 'public': request.data.get('permissions_public'), + 'authorized_ids': request.data.getlist('permissions_authorized_ids', []), + }, + } + + # init ledger serializer + ledger_data.update({'instance': instance}) + ledger_serializer = LedgerAlgoSerializer(data=ledger_data, + context={'request': request}) + if not ledger_serializer.is_valid(): + # delete instance + instance.delete() + raise ValidationError(ledger_serializer.errors) + + # create on ledger + try: + data = ledger_serializer.create(ledger_serializer.validated_data) + except LedgerTimeout as e: + data = {'pkhash': [x['pkhash'] for x in serializer.data], 'validated': False} + raise LedgerException(data, e.status) + except LedgerConflict as e: + raise ValidationException(e.msg, e.pkhash, e.status) + except LedgerError as e: + instance.delete() + raise LedgerException(str(e.msg), e.status) + except Exception: + instance.delete() + raise + + d = dict(serializer.data) + d.update(data) + + return d + + def _create(self, request, file): + + pkhash = get_hash(file) + serializer = self.get_serializer(data={ + 'pkhash': pkhash, + 'file': file, + 'description': request.data.get('description') + }) + + try: + serializer.is_valid(raise_exception=True) + except Exception as e: + st = status.HTTP_400_BAD_REQUEST + if find_primary_key_error(e): + st = status.HTTP_409_CONFLICT + raise ValidationException(e.args, pkhash, st) + else: + # create on ledger + db + return self.commit(serializer, request) + + def create(self, request, *args, **kwargs): + file = request.data.get('file') + + try: + data = self._create(request, file) + except ValidationException as e: + return Response({'message': e.data, 'pkhash': e.pkhash}, status=e.st) + except LedgerException as e: + return Response({'message': e.data}, status=e.st) + except Exception as e: + return Response({'message': str(e)}, status=status.HTTP_400_BAD_REQUEST) + else: + headers = self.get_success_headers(data) + st = get_success_create_code() + return Response(data, status=st, headers=headers) + + def create_or_update_algo(self, algo, pk): + # get algo description from remote node + url = algo['description']['storageAddress'] + + content = get_remote_asset(url, algo['owner'], algo['description']['hash']) + + f = tempfile.TemporaryFile() + f.write(content) + + # save/update algo description in local db for later use + instance, created = Algo.objects.update_or_create(pkhash=pk, validated=True) + instance.description.save('description.md', f) + + return instance + + def _retrieve(self, request, pk): + validate_pk(pk) + data = get_object_from_ledger(pk, self.ledger_query_call) + + # do not cache if node has not process permission + if node_has_process_permission(data): + # try to get it from local db to check if description exists + try: + instance = self.get_object() + except Http404: + instance = None + finally: + # check if instance has description + if not instance or not instance.description: + instance = self.create_or_update_algo(data, pk) + + # For security reasons, do not give access to local file address + # Restrict data to some fields + # TODO: do we need to send creation date and/or last modified date?
+ serializer = self.get_serializer(instance, fields=('owner', 'pkhash')) + data.update(serializer.data) + + replace_storage_addresses(request, data) + + return data + + def retrieve(self, request, *args, **kwargs): + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + pk = self.kwargs[lookup_url_kwarg] + + try: + data = self._retrieve(request, pk) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + except Exception as e: + return Response({'message': str(e)}, status.HTTP_400_BAD_REQUEST) + else: + return Response(data, status=status.HTTP_200_OK) + + def list(self, request, *args, **kwargs): + try: + data = query_ledger(fcn='queryAlgos', args=[]) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + + algos_list = [data] + + # parse filters + query_params = request.query_params.get('search', None) + + if query_params is not None: + try: + algos_list = filter_list( + object_type='algo', + data=data, + query_params=query_params) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + except Exception as e: + logging.exception(e) + return Response( + {'message': f'Malformed search filters {query_params}'}, + status=status.HTTP_400_BAD_REQUEST) + + for group in algos_list: + for algo in group: + replace_storage_addresses(request, algo) + + return Response(algos_list, status=status.HTTP_200_OK) + + +class AlgoPermissionViewSet(PermissionMixin, + GenericViewSet): + queryset = Algo.objects.all() + serializer_class = AlgoSerializer + ledger_query_call = 'queryAlgo' + + @action(detail=True) + def file(self, request, *args, **kwargs): + return self.download_file(request, 'file', 'content') + + @action(detail=True) + def description(self, request, *args, **kwargs): + return self.download_file(request, 'description') diff --git a/backend/substrapp/views/computeplan.py b/backend/substrapp/views/computeplan.py new file mode 100644 index 000000000..f01790373 --- /dev/null +++ b/backend/substrapp/views/computeplan.py @@ -0,0 +1,39 @@ +from rest_framework import mixins +from rest_framework.response import Response +from rest_framework.viewsets import GenericViewSet + +from substrapp.serializers import LedgerComputePlanSerializer +from substrapp.ledger_utils import query_ledger, LedgerError +from substrapp.views.utils import get_success_create_code + + +class ComputePlanViewSet(mixins.CreateModelMixin, + GenericViewSet): + + serializer_class = LedgerComputePlanSerializer + + def create(self, request, *args, **kwargs): + # rely on serializer to parse and validate request data + serializer = self.get_serializer(data=dict(request.data)) + serializer.is_valid(raise_exception=True) + + # get compute_plan_id to handle 408 timeout in next invoke ledger request + args = serializer.get_args(serializer.validated_data) + try: + ledger_response = query_ledger(fcn='createComputePlan', args=args) + except LedgerError as e: + error = {'message': str(e.msg)} + return Response(error, status=e.status) + + # create compute plan in ledger + compute_plan_id = ledger_response.get('computePlanID') + try: + data = serializer.create(serializer.validated_data) + except LedgerError as e: + error = {'message': str(e.msg), 'computePlanID': compute_plan_id} + return Response(error, status=e.status) + + # send successful response + headers = self.get_success_headers(data) + status = get_success_create_code() + return Response(data, status=status, headers=headers) diff --git a/backend/substrapp/views/datamanager.py 
b/backend/substrapp/views/datamanager.py new file mode 100644 index 000000000..982e6fa1a --- /dev/null +++ b/backend/substrapp/views/datamanager.py @@ -0,0 +1,277 @@ +import tempfile +import logging +from django.conf import settings +from django.http import Http404 +from django.urls import reverse +from rest_framework import status, mixins +from rest_framework.decorators import action +from rest_framework.exceptions import ValidationError +from rest_framework.response import Response +from rest_framework.viewsets import GenericViewSet + +# from hfc.fabric import Client +# cli = Client(net_profile="../network.json") +from substrapp.models import DataManager +from substrapp.serializers import DataManagerSerializer, LedgerDataManagerSerializer +from substrapp.serializers.ledger.datamanager.util import updateLedgerDataManager +from substrapp.serializers.ledger.datamanager.tasks import updateLedgerDataManagerAsync +from substrapp.utils import get_hash +from substrapp.ledger_utils import query_ledger, get_object_from_ledger, LedgerError, LedgerTimeout, LedgerConflict +from substrapp.views.utils import (PermissionMixin, find_primary_key_error, + validate_pk, get_success_create_code, ValidationException, LedgerException, + get_remote_asset, node_has_process_permission) +from substrapp.views.filters_utils import filter_list + + +def replace_storage_addresses(request, data_manager): + data_manager['description']['storageAddress'] = request.build_absolute_uri( + reverse('substrapp:data_manager-description', args=[data_manager['key']])) + data_manager['opener']['storageAddress'] = request.build_absolute_uri( + reverse('substrapp:data_manager-opener', args=[data_manager['key']]) + ) + + +class DataManagerViewSet(mixins.CreateModelMixin, + mixins.RetrieveModelMixin, + mixins.ListModelMixin, + GenericViewSet): + queryset = DataManager.objects.all() + serializer_class = DataManagerSerializer + ledger_query_call = 'queryDataManager' + + def perform_create(self, serializer): + return serializer.save() + + def commit(self, serializer, request): + # create on ledger + db + ledger_data = { + 'name': request.data.get('name'), + # XXX workaround because input is a QueryDict and not a JSON object. 
This + # is due to the fact that we are sending file object and body in a + # single HTTP request + 'permissions': { + 'public': request.data.get('permissions_public'), + 'authorized_ids': request.data.getlist('permissions_authorized_ids', []), + }, + 'type': request.data.get('type'), + 'objective_keys': request.data.getlist('objective_keys'), + } + + # create on db + instance = self.perform_create(serializer) + # init ledger serializer + ledger_data.update({'instance': instance}) + ledger_serializer = LedgerDataManagerSerializer(data=ledger_data, + context={'request': request}) + + if not ledger_serializer.is_valid(): + # delete instance + instance.delete() + raise ValidationError(ledger_serializer.errors) + + # create on ledger + try: + data = ledger_serializer.create(ledger_serializer.validated_data) + except LedgerTimeout as e: + data = {'pkhash': [x['pkhash'] for x in serializer.data], 'validated': False} + raise LedgerException(data, e.status) + except LedgerConflict as e: + raise ValidationException(e.msg, e.pkhash, e.status) + except LedgerError as e: + instance.delete() + raise LedgerException(str(e.msg), e.status) + except Exception: + instance.delete() + raise + + d = dict(serializer.data) + d.update(data) + + return d + + def _create(self, request, data_opener): + pkhash = get_hash(data_opener) + serializer = self.get_serializer(data={ + 'pkhash': pkhash, + 'data_opener': data_opener, + 'description': request.data.get('description'), + 'name': request.data.get('name'), + }) + + try: + serializer.is_valid(raise_exception=True) + except Exception as e: + st = status.HTTP_400_BAD_REQUEST + if find_primary_key_error(e): + st = status.HTTP_409_CONFLICT + raise ValidationException(e.args, pkhash, st) + else: + # create on ledger + db + return self.commit(serializer, request) + + def create(self, request, *args, **kwargs): + data_opener = request.data.get('data_opener') + + try: + data = self._create(request, data_opener) + except ValidationException as e: + return Response({'message': e.data, 'pkhash': e.pkhash}, status=e.st) + except LedgerException as e: + return Response({'message': e.data}, status=e.st) + except Exception as e: + return Response({'message': str(e)}, status=status.HTTP_400_BAD_REQUEST) + else: + headers = self.get_success_headers(data) + st = get_success_create_code() + return Response(data, status=st, headers=headers) + + def create_or_update_datamanager(self, instance, datamanager, pk): + + # create instance if does not exist + if not instance: + instance, created = DataManager.objects.update_or_create( + pkhash=pk, name=datamanager['name'], validated=True) + + if not instance.data_opener: + url = datamanager['opener']['storageAddress'] + + content = get_remote_asset(url, datamanager['owner'], datamanager['opener']['hash']) + + f = tempfile.TemporaryFile() + f.write(content) + + # save/update data_opener in local db for later use + instance.data_opener.save('opener.py', f) + + # do the same for description + if not instance.description: + url = datamanager['description']['storageAddress'] + + content = get_remote_asset(url, datamanager['owner'], datamanager['description']['hash']) + + f = tempfile.TemporaryFile() + f.write(content) + + # save/update description in local db for later use + instance.description.save('description.md', f) + + return instance + + def _retrieve(self, request, pk): + validate_pk(pk) + # get instance from remote node + data = get_object_from_ledger(pk, 'queryDataset') + + # do not cache if node has not process permission + if 
node has not process permission + if
node_has_process_permission(data): + # try to get it from local db to check if description exists + try: + instance = self.get_object() + except Http404: + instance = None + finally: + # check if instance has description or data_opener + if not instance or not instance.description or not instance.data_opener: + instance = self.create_or_update_datamanager(instance, data, pk) + + # do not give access to local files address + serializer = self.get_serializer(instance, fields=('owner', 'pkhash', 'creation_date', 'last_modified')) + data.update(serializer.data) + + replace_storage_addresses(request, data) + + return data + + def retrieve(self, request, *args, **kwargs): + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + pk = self.kwargs[lookup_url_kwarg] + + try: + data = self._retrieve(request, pk) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + except Exception as e: + return Response({'message': str(e)}, status=status.HTTP_400_BAD_REQUEST) + else: + return Response(data, status=status.HTTP_200_OK) + + def list(self, request, *args, **kwargs): + + try: + data = query_ledger(fcn='queryDataManagers', args=[]) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + + data_managers_list = [data] + + # parse filters + query_params = request.query_params.get('search', None) + + if query_params is not None: + try: + data_managers_list = filter_list( + object_type='dataset', + data=data, + query_params=query_params) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + except Exception as e: + logging.exception(e) + return Response( + {'message': f'Malformed search filters {query_params}'}, + status=status.HTTP_400_BAD_REQUEST) + + for group in data_managers_list: + for data_manager in group: + replace_storage_addresses(request, data_manager) + + return Response(data_managers_list, status=status.HTTP_200_OK) + + @action(methods=['post'], detail=True) + def update_ledger(self, request, *args, **kwargs): + + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + pk = self.kwargs[lookup_url_kwarg] + + try: + validate_pk(pk) + except Exception as e: + return Response({'message': str(e)}, status=status.HTTP_400_BAD_REQUEST) + + objective_key = request.data.get('objective_key') + args = { + 'dataManagerKey': pk, + 'objectiveKey': objective_key, + } + + if getattr(settings, 'LEDGER_SYNC_ENABLED'): + try: + data = updateLedgerDataManager(args, sync=True) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + st = status.HTTP_200_OK + + else: + # use a celery task, as we are in an http request transaction + updateLedgerDataManagerAsync.delay(args) + data = { + 'message': 'The substra network has been notified for updating this DataManager' + } + st = status.HTTP_202_ACCEPTED + + return Response(data, status=st) + + +class DataManagerPermissionViewSet(PermissionMixin, + GenericViewSet): + queryset = DataManager.objects.all() + serializer_class = DataManagerSerializer + ledger_query_call = 'queryDataManager' + + @action(detail=True) + def description(self, request, *args, **kwargs): + return self.download_file(request, 'description') + + @action(detail=True) + def opener(self, request, *args, **kwargs): + return self.download_file(request, 'data_opener', 'opener') diff --git a/backend/substrapp/views/datasample.py b/backend/substrapp/views/datasample.py new file mode 100644 index 000000000..007fc1b2b --- /dev/null +++ 
b/backend/substrapp/views/datasample.py @@ -0,0 +1,273 @@ +import logging +from os.path import normpath + +import os +import ntpath +import shutil + +from checksumdir import dirhash +from django.conf import settings +from rest_framework import status, mixins +from rest_framework.decorators import action +from rest_framework.exceptions import ValidationError +from rest_framework.fields import BooleanField +from rest_framework.response import Response +from rest_framework.viewsets import GenericViewSet + +from substrapp.models import DataSample, DataManager +from substrapp.serializers import DataSampleSerializer, LedgerDataSampleSerializer +from substrapp.serializers.ledger.datasample.util import updateLedgerDataSample +from substrapp.serializers.ledger.datasample.tasks import updateLedgerDataSampleAsync +from substrapp.utils import store_datasamples_archive +from substrapp.views.utils import find_primary_key_error, LedgerException, ValidationException, \ + get_success_create_code +from substrapp.ledger_utils import query_ledger, LedgerError, LedgerTimeout, LedgerConflict + +logger = logging.getLogger('django.request') + + +class DataSampleViewSet(mixins.CreateModelMixin, + mixins.RetrieveModelMixin, + mixins.ListModelMixin, + GenericViewSet): + queryset = DataSample.objects.all() + serializer_class = DataSampleSerializer + + @staticmethod + def check_datamanagers(data_manager_keys): + datamanager_count = DataManager.objects.filter(pkhash__in=data_manager_keys).count() + + if datamanager_count != len(data_manager_keys): + raise Exception(f'One or more datamanager keys provided do not exist in local database. ' + f'Please create them before. DataManager keys: {data_manager_keys}') + + @staticmethod + def commit(serializer, ledger_data): + instances = serializer.save() + # init ledger serializer + ledger_data.update({'instances': instances}) + ledger_serializer = LedgerDataSampleSerializer(data=ledger_data) + + if not ledger_serializer.is_valid(): + # delete instance + for instance in instances: + instance.delete() + raise ValidationError(ledger_serializer.errors) + + # create on ledger + try: + data = ledger_serializer.create(ledger_serializer.validated_data) + except LedgerTimeout as e: + data = {'pkhash': [x['pkhash'] for x in serializer.data], 'validated': False} + raise LedgerException(data, e.status) + except LedgerConflict as e: + raise ValidationException(e.msg, e.pkhash, e.status) + except LedgerError as e: + for instance in instances: + instance.delete() + raise LedgerException(str(e.msg), e.status) + except Exception: + for instance in instances: + instance.delete() + raise + + st = get_success_create_code() + + # update validated to True in response + if 'pkhash' in data and data['validated']: + for d in serializer.data: + if d['pkhash'] in data['pkhash']: + d.update({'validated': data['validated']}) + + return serializer.data, st + + def compute_data(self, request, paths_to_remove): + + data = {} + + # files can be uploaded inside the HTTP request or can already be + # available on local disk + if len(request.FILES) > 0: + pkhash_map = {} + + for k, file in request.FILES.items(): + # Get dir hash uncompress the file into a directory + pkhash, datasamples_path_from_file = store_datasamples_archive(file) # can raise + paths_to_remove.append(datasamples_path_from_file) + # check pkhash does not belong to the list + try: + data[pkhash] + except KeyError: + pkhash_map[pkhash] = file + else: + raise Exception(f'Your data sample archives contain same files leading to same pkhash, ' + 
f'please review the content of your archives. '
+                                        f'Archives {file} and {pkhash_map[pkhash]} are the same')
+                data[pkhash] = {
+                    'pkhash': pkhash,
+                    'path': datasamples_path_from_file
+                }
+
+        else:  # files must be available on local filesystem
+            path = request.POST.get('path')
+            paths = request.POST.getlist('paths')
+
+            if path and paths:
+                raise Exception('Cannot use path and paths together.')
+            if path is not None:
+                paths = [path]
+
+            recursive_dir_field = BooleanField()
+            recursive_dir = recursive_dir_field.to_internal_value(request.data.get('multiple', 'false'))
+            if recursive_dir:
+                # list all directories from parent directories
+                parent_paths = paths
+                paths = []
+                for parent_path in parent_paths:
+                    subdirs = next(os.walk(parent_path))[1]
+                    subdirs = [os.path.join(parent_path, s) for s in subdirs]
+                    if not subdirs:
+                        raise Exception(
+                            f'No data sample directories in folder {parent_path}')
+                    paths.extend(subdirs)
+
+            # paths should be directories
+            for path in paths:
+                if not os.path.isdir(path):
+                    raise Exception(f'One of your paths does not exist, '
+                                    f'is not a directory or is not an absolute path: {path}')
+                pkhash = dirhash(path, 'sha256')
+                try:
+                    data[pkhash]
+                except KeyError:
+                    pass
+                else:
+                    # existing can be a dict with a field path or file
+                    raise Exception(f'Your data sample directories contain the same files, leading to the same pkhash. '
+                                    f'Invalid path: {path}.')
+
+                data[pkhash] = {
+                    'pkhash': pkhash,
+                    'path': normpath(path)
+                }
+
+        if not data:
+            raise Exception('No data sample provided.')
+
+        return list(data.values())
+
+    def _create(self, request, data_manager_keys, test_only):
+
+        # compute_data will uncompress data archives to paths which will be
+        # hardlinked thanks to the datasample pre_save signal.
+        # In all other cases, we need to remove those references.
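+        # Note: the uncompressed directories collected in `paths_to_remove` are
+        # always cleaned up in the `finally` block below, whether the creation
+        # succeeds or fails; the pre_save signal is what keeps a hardlinked
+        # copy of the accepted data samples.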
+ + if not data_manager_keys: + raise Exception("missing or empty field 'data_manager_keys'") + + self.check_datamanagers(data_manager_keys) # can raise + + paths_to_remove = [] + + try: + # will uncompress data archives to paths + computed_data = self.compute_data(request, paths_to_remove) + + serializer = self.get_serializer(data=computed_data, many=True) + + try: + serializer.is_valid(raise_exception=True) + except Exception as e: + pkhashes = [x['pkhash'] for x in computed_data] + st = status.HTTP_400_BAD_REQUEST + if find_primary_key_error(e): + st = status.HTTP_409_CONFLICT + raise ValidationException(e.args, pkhashes, st) + else: + + # create on ledger + db + ledger_data = {'test_only': test_only, + 'data_manager_keys': data_manager_keys} + data, st = self.commit(serializer, ledger_data) # pre_save signal executed + return data, st + finally: + for gpath in paths_to_remove: + shutil.rmtree(gpath, ignore_errors=True) + + def create(self, request, *args, **kwargs): + test_only = request.data.get('test_only', False) + data_manager_keys = request.data.getlist('data_manager_keys', []) + + try: + data, st = self._create(request, data_manager_keys, test_only) + except ValidationException as e: + return Response({'message': e.data, 'pkhash': e.pkhash}, status=e.st) + except LedgerException as e: + return Response({'message': e.data}, status=e.st) + except Exception as e: + return Response({'message': str(e)}, status=status.HTTP_400_BAD_REQUEST) + else: + headers = self.get_success_headers(data) + return Response(data, status=st, headers=headers) + + def list(self, request, *args, **kwargs): + try: + data = query_ledger(fcn='queryDataSamples', args=[]) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + + data = data if data else [] + + return Response(data, status=status.HTTP_200_OK) + + def validate_bulk_update(self, data): + try: + data_manager_keys = data.getlist('data_manager_keys') + except KeyError: + data_manager_keys = [] + if not data_manager_keys: + raise Exception('Please pass a non empty data_manager_keys key param') + + try: + data_sample_keys = data.getlist('data_sample_keys') + except KeyError: + data_sample_keys = [] + if not data_sample_keys: + raise Exception('Please pass a non empty data_sample_keys key param') + + return data_manager_keys, data_sample_keys + + @action(methods=['post'], detail=False) + def bulk_update(self, request): + try: + data_manager_keys, data_sample_keys = self.validate_bulk_update(request.data) + except Exception as e: + return Response({'message': str(e)}, status=status.HTTP_400_BAD_REQUEST) + else: + args = { + 'hashes': data_sample_keys, + 'dataManagerKeys': data_manager_keys, + } + + if getattr(settings, 'LEDGER_SYNC_ENABLED'): + try: + data = updateLedgerDataSample(args, sync=True) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.st) + + st = status.HTTP_200_OK + + else: + # use a celery task, as we are in an http request transaction + updateLedgerDataSampleAsync.delay(args) + data = { + 'message': 'The substra network has been notified for updating these Data' + } + st = status.HTTP_202_ACCEPTED + + return Response(data, status=st) + + +def path_leaf(path): + head, tail = ntpath.split(path) + return tail or ntpath.basename(head) diff --git a/backend/substrapp/views/filters_utils.py b/backend/substrapp/views/filters_utils.py new file mode 100644 index 000000000..509c17825 --- /dev/null +++ b/backend/substrapp/views/filters_utils.py @@ -0,0 +1,162 @@ +from urllib.parse import 
unquote + +from substrapp.ledger_utils import query_ledger + + +FILTER_QUERIES = { + 'dataset': 'queryDataManagers', + 'algo': 'queryAlgos', + 'objective': 'queryObjectives', + 'model': 'queryTraintuples', +} + +AUTHORIZED_FILTERS = { + 'dataset': ['dataset', 'model', 'objective'], + 'algo': ['model', 'algo'], + 'objective': ['model', 'dataset', 'objective'], + 'model': ['model', 'algo', 'dataset', 'objective'], + 'traintuple': ['traintuple'], + 'testtuple': ['testtuple'], +} + + +def get_filters(query_params): + filters = [] + groups = query_params.split('-OR-') + + for idx, group in enumerate(groups): + + # init + filters.append({}) + + # get number of subfilters and decode them + subfilters = [unquote(x) for x in group.split(',')] + + for subfilter in subfilters: + el = subfilter.split(':') + # get parent + parent = el[0] + subparent = el[1] + value = el[2] + + filter = { + subparent: [unquote(value)] + } + + if not len(filters[idx]): # create and add it + filters[idx] = { + parent: filter + } + else: # add it + if parent in filters[idx]: # add + if el[1] in filters[idx][parent]: # concat in subparent + filters[idx][parent][subparent].extend([value]) + else: # add new subparent + filters[idx][parent].update(filter) + else: # create + filters[idx].update({parent: filter}) + + return filters + + +def filter_list(object_type, data, query_params): + + filters = get_filters(query_params) + + object_list = [] + + for user_filter in filters: + + for filter_key, subfilters in user_filter.items(): + + if filter_key not in AUTHORIZED_FILTERS[object_type]: + raise Exception(f'Not authorized filter key {filter_key} for asset {object_type}') + + # Will be appended in object_list after been filtered + filtered_list = data + + if filter_key == object_type: + # Filter by own asset + if filter_key == 'model': + for attribute, val in subfilters.items(): + filtered_list = [x for x in filtered_list + if x['traintuple']['outModel'] is not None and + x['traintuple']['outModel']['hash'] in val] + elif filter_key == 'objective': + for attribute, val in subfilters.items(): + if attribute == 'metrics': # specific to nested metrics + filtered_list = [x for x in filtered_list if x[attribute]['name'] in val] + else: + filtered_list = [x for x in filtered_list if x[attribute] in val] + + else: + for attribute, val in subfilters.items(): + filtered_list = [x for x in filtered_list if x.get(attribute) in val] + else: + # Filter by other asset + + # Get other asset list + filtering_data = query_ledger(fcn=FILTER_QUERIES[filter_key], args=[]) + + filtering_data = filtering_data if filtering_data else [] + + if filter_key == 'algo': + for attribute, val in subfilters.items(): + filtering_data = [x for x in filtering_data if x[attribute] in val] + hashes = [x['key'] for x in filtering_data] + + if object_type == 'model': + filtered_list = [x for x in filtered_list + if x['traintuple']['algo']['hash'] in hashes] + + elif filter_key == 'model': + for attribute, val in subfilters.items(): + filtering_data = [x for x in filtering_data + if x['outModel'] is not None and x['outModel'][attribute] in val] + + if object_type == 'algo': + hashes = [x['algo']['hash'] for x in filtering_data] + filtered_list = [x for x in filtered_list if x['key'] in hashes] + + elif object_type == 'dataset': + hashes = [x['objective']['hash'] for x in filtering_data] + filtered_list = [x for x in filtered_list + if x['objectiveKey'] in hashes] + + elif object_type == 'objective': + hashes = [x['objective']['hash'] for x in filtering_data] + 
filtered_list = [x for x in filtered_list if x['key'] in hashes] + + elif filter_key == 'dataset': + for attribute, val in subfilters.items(): + filtering_data = [x for x in filtering_data if x[attribute] in val] + hashes = [x['key'] for x in filtering_data] + + if object_type == 'model': + filtered_list = [x for x in filtered_list + if x['traintuple']['dataset']['openerHash'] in hashes] + elif object_type == 'objective': + objectiveKeys = [x['objectiveKey'] for x in filtering_data] + filtered_list = [x for x in filtered_list + if x['key'] in objectiveKeys or + (x['testDataset'] and x['testDataset']['dataManagerKey'] in hashes)] + + elif filter_key == 'objective': + for attribute, val in subfilters.items(): + if attribute == 'metrics': # specific to nested metrics + filtering_data = [x for x in filtering_data if x[attribute]['name'] in val] + else: + filtering_data = [x for x in filtering_data if x[attribute] in val] + + hashes = [x['key'] for x in filtering_data] + + if object_type == 'model': + filtered_list = [x for x in filtered_list + if x['traintuple']['objective']['hash'] in hashes] + elif object_type == 'dataset': + filtered_list = [x for x in filtered_list + if x['objectiveKey'] in hashes] + + object_list.append(filtered_list) + + return object_list diff --git a/backend/substrapp/views/model.py b/backend/substrapp/views/model.py new file mode 100644 index 000000000..171a31e99 --- /dev/null +++ b/backend/substrapp/views/model.py @@ -0,0 +1,139 @@ +import os +import tempfile +import logging +from django.http import Http404 +from rest_framework import status, mixins +from rest_framework.decorators import action +from rest_framework.response import Response +from rest_framework.viewsets import GenericViewSet + +from node.authentication import NodeUser +from substrapp.models import Model +from substrapp.serializers import ModelSerializer +from substrapp.ledger_utils import query_ledger, get_object_from_ledger, LedgerError +from substrapp.views.utils import CustomFileResponse, validate_pk, get_remote_asset, PermissionMixin +from substrapp.views.filters_utils import filter_list + + +class ModelViewSet(mixins.RetrieveModelMixin, + mixins.ListModelMixin, + GenericViewSet): + queryset = Model.objects.all() + serializer_class = ModelSerializer + ledger_query_call = 'queryModelDetails' + # permission_classes = (permissions.IsAuthenticated,) + + def create_or_update_model(self, traintuple, pk): + if traintuple['outModel'] is None: + raise Exception(f'This traintuple related to this model key {pk} does not have a outModel') + + # get model from remote node + url = traintuple['outModel']['storageAddress'] + + content = get_remote_asset(url, traintuple['creator'], traintuple['key']) + + # write model in local db for later use + tmp_model = tempfile.TemporaryFile() + tmp_model.write(content) + instance, created = Model.objects.update_or_create(pkhash=pk, validated=True) + instance.file.save('model', tmp_model) + + return instance + + def _retrieve(self, pk): + validate_pk(pk) + + data = get_object_from_ledger(pk, self.ledger_query_call) + if not data or not data.get('traintuple'): + raise Exception('Invalid model: missing traintuple field') + if data['traintuple'].get('status') != "done": + raise Exception("Invalid model: traintuple must be at status done") + + # Try to get it from local db, else create it in local db + try: + instance = self.get_object() + except Http404: + instance = None + + if not instance or not instance.file: + instance = self.create_or_update_model(data['traintuple'], + 
data['traintuple']['outModel']['hash']) + + # For security reason, do not give access to local file address + # Restrain data to some fields + # TODO: do we need to send creation date and/or last modified date ? + serializer = self.get_serializer(instance, fields=('owner', 'pkhash')) + data.update(serializer.data) + + return data + + def retrieve(self, request, *args, **kwargs): + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + pk = self.kwargs[lookup_url_kwarg] + + try: + data = self._retrieve(pk) + except LedgerError as e: + logging.exception(e) + return Response({'message': str(e.msg)}, status=e.status) + except Exception as e: + logging.exception(e) + return Response({'message': str(e)}, status.HTTP_400_BAD_REQUEST) + else: + return Response(data, status=status.HTTP_200_OK) + + def list(self, request, *args, **kwargs): + try: + data = query_ledger(fcn='queryModels', args=[]) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + + models_list = [data] + + query_params = request.query_params.get('search', None) + if query_params is not None: + try: + models_list = filter_list( + object_type='model', + data=data, + query_params=query_params) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + except Exception as e: + logging.exception(e) + return Response( + {'message': f'Malformed search filters {query_params}'}, + status=status.HTTP_400_BAD_REQUEST) + + return Response(models_list, status=status.HTTP_200_OK) + + @action(detail=True) + def details(self, request, *args, **kwargs): + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + pk = self.kwargs[lookup_url_kwarg] + + try: + data = get_object_from_ledger(pk, self.ledger_query_call) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + + return Response(data, status=status.HTTP_200_OK) + + +class ModelPermissionViewSet(PermissionMixin, + GenericViewSet): + + queryset = Model.objects.all() + serializer_class = ModelSerializer + ledger_query_call = 'queryModelDetails' + + @action(detail=True) + def file(self, request, *args, **kwargs): + + # user cannot download model, only node can + if not isinstance(request.user, NodeUser): + return Response({}, status=status.HTTP_403_FORBIDDEN) + + model_object = self.get_object() + data = getattr(model_object, 'file') + return CustomFileResponse(open(data.path, 'rb'), as_attachment=True, filename=os.path.basename(data.path)) diff --git a/backend/substrapp/views/objective.py b/backend/substrapp/views/objective.py new file mode 100644 index 000000000..7aaa43e59 --- /dev/null +++ b/backend/substrapp/views/objective.py @@ -0,0 +1,265 @@ +import logging +import re +import tempfile + +from django.db import IntegrityError +from django.http import Http404 +from django.urls import reverse +from rest_framework import status, mixins +from rest_framework.decorators import action +from rest_framework.exceptions import ValidationError +from rest_framework.response import Response +from rest_framework.viewsets import GenericViewSet + + +from substrapp.models import Objective +from substrapp.serializers import ObjectiveSerializer, LedgerObjectiveSerializer + +from substrapp.ledger_utils import query_ledger, get_object_from_ledger, LedgerError, LedgerTimeout, LedgerConflict +from substrapp.utils import get_hash +from substrapp.views.utils import (PermissionMixin, find_primary_key_error, validate_pk, + get_success_create_code, ValidationException, + LedgerException, get_remote_asset, 
validate_sort, + node_has_process_permission) +from substrapp.views.filters_utils import filter_list + + +def replace_storage_addresses(request, objective): + objective['description']['storageAddress'] = request.build_absolute_uri( + reverse('substrapp:objective-description', args=[objective['key']])) + objective['metrics']['storageAddress'] = request.build_absolute_uri( + reverse('substrapp:objective-metrics', args=[objective['key']]) + ) + + +class ObjectiveViewSet(mixins.CreateModelMixin, + mixins.ListModelMixin, + mixins.RetrieveModelMixin, + GenericViewSet): + queryset = Objective.objects.all() + serializer_class = ObjectiveSerializer + ledger_query_call = 'queryObjective' + + def perform_create(self, serializer): + return serializer.save() + + def commit(self, serializer, request): + # create on local db + try: + instance = self.perform_create(serializer) + except IntegrityError as e: + try: + pkhash = re.search(r'\(pkhash\)=\((\w+)\)', e.args[0]).group(1) + except IndexError: + pkhash = '' + err_msg = 'A objective with this description file already exists.' + return {'message': err_msg, 'pkhash': pkhash}, status.HTTP_409_CONFLICT + except Exception as e: + raise Exception(e.args) + + # init ledger serializer + ledger_data = { + 'test_data_sample_keys': request.data.getlist('test_data_sample_keys', []), + 'test_data_manager_key': request.data.get('test_data_manager_key', ''), + 'name': request.data.get('name'), + # XXX workaround because input is a QueryDict and not a JSON object. This + # is due to the fact that we are sending file object and body in a + # single HTTP request + 'permissions': { + 'public': request.data.get('permissions_public'), + 'authorized_ids': request.data.getlist('permissions_authorized_ids', []), + }, + 'metrics_name': request.data.get('metrics_name'), + } + ledger_data.update({'instance': instance}) + ledger_serializer = LedgerObjectiveSerializer(data=ledger_data, + context={'request': request}) + + if not ledger_serializer.is_valid(): + # delete instance + instance.delete() + raise ValidationError(ledger_serializer.errors) + + # create on ledger + try: + data = ledger_serializer.create(ledger_serializer.validated_data) + except LedgerTimeout as e: + data = {'pkhash': [x['pkhash'] for x in serializer.data], 'validated': False} + raise LedgerException(data, e.status) + except LedgerConflict as e: + raise ValidationException(e.msg, e.pkhash, e.status) + except LedgerError as e: + instance.delete() + raise LedgerException(str(e.msg), e.status) + except Exception: + instance.delete() + raise + + d = dict(serializer.data) + d.update(data) + + return d + + def _create(self, request): + metrics = request.data.get('metrics') + description = request.data.get('description') + + pkhash = get_hash(description) + + serializer = self.get_serializer(data={ + 'pkhash': pkhash, + 'metrics': metrics, + 'description': description, + }) + + try: + serializer.is_valid(raise_exception=True) + except Exception as e: + st = status.HTTP_400_BAD_REQUEST + if find_primary_key_error(e): + st = status.HTTP_409_CONFLICT + raise ValidationException(e.args, pkhash, st) + else: + # create on ledger + db + return self.commit(serializer, request) + + def create(self, request, *args, **kwargs): + + try: + data = self._create(request) + except ValidationException as e: + return Response({'message': e.data, 'pkhash': e.pkhash}, status=e.st) + except LedgerException as e: + return Response({'message': e.data}, status=e.st) + except Exception as e: + return Response({'message': str(e)}, 
status=status.HTTP_400_BAD_REQUEST) + else: + headers = self.get_success_headers(data) + st = get_success_create_code() + return Response(data, status=st, headers=headers) + + def create_or_update_objective(self, objective, pk): + # get description from remote node + url = objective['description']['storageAddress'] + + content = get_remote_asset(url, objective['owner'], pk) + + # write objective with description in local db for later use + tmp_description = tempfile.TemporaryFile() + tmp_description.write(content) + instance, created = Objective.objects.update_or_create(pkhash=pk, validated=True) + instance.description.save('description.md', tmp_description) + return instance + + def _retrieve(self, request, pk): + validate_pk(pk) + # get instance from remote node + data = get_object_from_ledger(pk, self.ledger_query_call) + + # do not cache if node has not process permission + if node_has_process_permission(data): + # try to get it from local db to check if description exists + try: + instance = self.get_object() + except Http404: + instance = None + + if not instance or not instance.description: + instance = self.create_or_update_objective(data, pk) + + # For security reason, do not give access to local file address + # Restrain data to some fields + # TODO: do we need to send creation date and/or last modified date ? + serializer = self.get_serializer(instance, fields=('owner', 'pkhash')) + data.update(serializer.data) + + replace_storage_addresses(request, data) + + return data + + def retrieve(self, request, *args, **kwargs): + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + pk = self.kwargs[lookup_url_kwarg] + + try: + data = self._retrieve(request, pk) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + except Exception as e: + return Response({'message': str(e)}, status.HTTP_400_BAD_REQUEST) + else: + return Response(data, status=status.HTTP_200_OK) + + def list(self, request, *args, **kwargs): + try: + data = query_ledger(fcn='queryObjectives', args=[]) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + + objectives_list = [data] + + query_params = request.query_params.get('search', None) + if query_params is not None: + try: + objectives_list = filter_list( + object_type='objective', + data=data, + query_params=query_params) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + except Exception as e: + logging.exception(e) + return Response( + {'message': f'Malformed search filters {query_params}'}, + status=status.HTTP_400_BAD_REQUEST) + + for group in objectives_list: + for objective in group: + replace_storage_addresses(request, objective) + + return Response(objectives_list, status=status.HTTP_200_OK) + + @action(detail=True) + def data(self, request, *args, **kwargs): + instance = self.get_object() + + # TODO fetch list of data from ledger + # query list of related algos and models from ledger + + serializer = self.get_serializer(instance) + return Response(serializer.data) + + @action(detail=True, methods=['GET']) + def leaderboard(self, request, pk): + validate_pk(pk) + + sort = request.query_params.get('sort', 'desc') + try: + validate_sort(sort) + except Exception as e: + return Response({'message': str(e)}, status=status.HTTP_400_BAD_REQUEST) + + try: + leaderboard = query_ledger(fcn='queryObjectiveLeaderboard', args={ + 'objectiveKey': pk, + 'ascendingOrder': sort == 'asc', + }) + except LedgerError as e: + return Response({'message': 
str(e.msg)}, status=e.status) + + return Response(leaderboard, status=status.HTTP_200_OK) + + +class ObjectivePermissionViewSet(PermissionMixin, + GenericViewSet): + queryset = Objective.objects.all() + serializer_class = ObjectiveSerializer + ledger_query_call = 'queryObjective' + + @action(detail=True) + def description(self, request, *args, **kwargs): + return self.download_file(request, 'description') + + @action(detail=True) + def metrics(self, request, *args, **kwargs): + return self.download_file(request, 'metrics') diff --git a/substrabac/substrapp/views/task.py b/backend/substrapp/views/task.py similarity index 81% rename from substrabac/substrapp/views/task.py rename to backend/substrapp/views/task.py index d7717ac91..00a690452 100644 --- a/substrabac/substrapp/views/task.py +++ b/backend/substrapp/views/task.py @@ -15,12 +15,13 @@ def retrieve(self, request, pk=None): data = { 'status': res.status } - except: + except Exception: return Response({'message': 'Can\'t get task status'}, status=status.HTTP_400_BAD_REQUEST) else: if not res.successful(): if res.status == 'PENDING': - data['message'] = 'Task is either waiting, does not exist in this context or has been removed after 24h' + data['message'] = 'Task is either waiting, ' \ + 'does not exist in this context or has been removed after 24h' else: data['message'] = res.traceback else: diff --git a/backend/substrapp/views/testtuple.py b/backend/substrapp/views/testtuple.py new file mode 100644 index 000000000..6cff7cf18 --- /dev/null +++ b/backend/substrapp/views/testtuple.py @@ -0,0 +1,109 @@ +import logging + +from rest_framework import mixins, status +from rest_framework.response import Response +from rest_framework.viewsets import GenericViewSet + +from substrapp.serializers import LedgerTestTupleSerializer +from substrapp.ledger_utils import query_ledger, get_object_from_ledger, LedgerError, LedgerConflict +from substrapp.views.filters_utils import filter_list +from substrapp.views.utils import validate_pk, get_success_create_code, LedgerException + + +class TestTupleViewSet(mixins.CreateModelMixin, + mixins.RetrieveModelMixin, + mixins.ListModelMixin, + GenericViewSet): + serializer_class = LedgerTestTupleSerializer + ledger_query_call = 'queryTesttuple' + + def get_queryset(self): + return [] + + def perform_create(self, serializer): + return serializer.save() + + def commit(self, serializer, pkhash): + # create on ledger + try: + data = serializer.create(serializer.validated_data) + except LedgerError as e: + raise LedgerException({'message': str(e.msg), 'pkhash': pkhash}, e.status) + else: + return data + + def _create(self, request): + data = { + 'traintuple_key': request.data.get('traintuple_key'), + 'data_manager_key': request.data.get('data_manager_key', ''), + 'test_data_sample_keys': request.data.getlist('test_data_sample_keys'), + 'tag': request.data.get('tag', '') + } + + serializer = self.get_serializer(data=data) + serializer.is_valid(raise_exception=True) + + # Get traintuple pkhash to handle 408 timeout in invoke_ledger + args = serializer.get_args(serializer.validated_data) + + try: + data = query_ledger(fcn='createTesttuple', args=args) + except LedgerConflict as e: + raise LedgerException({'message': str(e.msg), 'pkhash': e.pkhash}, e.status) + except LedgerError as e: + raise LedgerException({'message': str(e.msg)}, e.status) + else: + pkhash = data.get('key') + return self.commit(serializer, pkhash) + + def create(self, request, *args, **kwargs): + try: + data = self._create(request) + except 
LedgerException as e: + return Response(e.data, status=e.st) + else: + headers = self.get_success_headers(data) + st = get_success_create_code() + return Response(data, status=st, headers=headers) + + def list(self, request, *args, **kwargs): + try: + data = query_ledger(fcn='queryTesttuples', args=[]) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + + testtuple_list = [data] + + query_params = request.query_params.get('search', None) + if query_params is not None: + try: + testtuple_list = filter_list( + object_type='testtuple', + data=data, + query_params=query_params) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + except Exception as e: + logging.exception(e) + return Response( + {'message': f'Malformed search filters {query_params}'}, + status=status.HTTP_400_BAD_REQUEST) + + return Response(testtuple_list, status=status.HTTP_200_OK) + + def _retrieve(self, pk): + validate_pk(pk) + return get_object_from_ledger(pk, self.ledger_query_call) + + def retrieve(self, request, *args, **kwargs): + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + pk = self.kwargs[lookup_url_kwarg] + + try: + data = self._retrieve(pk) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + except Exception as e: + return Response({'message': str(e)}, status.HTTP_400_BAD_REQUEST) + else: + return Response(data, status=status.HTTP_200_OK) diff --git a/backend/substrapp/views/traintuple.py b/backend/substrapp/views/traintuple.py new file mode 100644 index 000000000..e637448d4 --- /dev/null +++ b/backend/substrapp/views/traintuple.py @@ -0,0 +1,114 @@ +import logging + +from rest_framework import mixins, status +from rest_framework.response import Response +from rest_framework.viewsets import GenericViewSet + +from substrapp.serializers import LedgerTrainTupleSerializer +from substrapp.ledger_utils import query_ledger, get_object_from_ledger, LedgerError, LedgerConflict +from substrapp.views.filters_utils import filter_list +from substrapp.views.utils import validate_pk, get_success_create_code, LedgerException + + +class TrainTupleViewSet(mixins.CreateModelMixin, + mixins.RetrieveModelMixin, + mixins.ListModelMixin, + GenericViewSet): + serializer_class = LedgerTrainTupleSerializer + ledger_query_call = 'queryTraintuple' + + def get_queryset(self): + return [] + + def perform_create(self, serializer): + return serializer.save() + + def commit(self, serializer, pkhash): + # create on ledger + try: + data = serializer.create(serializer.validated_data) + except LedgerError as e: + raise LedgerException({'message': str(e.msg), 'pkhash': pkhash}, e.status) + else: + return data + + def _create(self, request): + data = { + 'algo_key': request.data.get('algo_key'), + 'data_manager_key': request.data.get('data_manager_key'), + 'objective_key': request.data.get('objective_key'), + 'rank': request.data.get('rank'), + 'compute_plan_id': request.data.get('compute_plan_id', ''), + 'in_models_keys': request.data.getlist('in_models_keys'), + # list of train data keys (which are stored in the train worker node) + 'train_data_sample_keys': request.data.getlist('train_data_sample_keys'), + 'tag': request.data.get('tag', '') + } + + serializer = self.get_serializer(data=data) + serializer.is_valid(raise_exception=True) + + # Get traintuple pkhash to handle 408 timeout in invoke_ledger + args = serializer.get_args(serializer.validated_data) + + try: + data = query_ledger(fcn='createTraintuple', args=args) + 
except LedgerConflict as e: + raise LedgerException({'message': str(e.msg), 'pkhash': e.pkhash}, e.status) + except LedgerError as e: + raise LedgerException({'message': str(e.msg)}, e.status) + else: + pkhash = data.get('key') + return self.commit(serializer, pkhash) + + def create(self, request, *args, **kwargs): + try: + data = self._create(request) + except LedgerException as e: + return Response(e.data, status=e.st) + else: + headers = self.get_success_headers(data) + st = get_success_create_code() + return Response(data, status=st, headers=headers) + + def list(self, request, *args, **kwargs): + try: + data = query_ledger(fcn='queryTraintuples', args=[]) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + + traintuple_list = [data] + + query_params = request.query_params.get('search', None) + if query_params is not None: + try: + traintuple_list = filter_list( + object_type='traintuple', + data=data, + query_params=query_params) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + except Exception as e: + logging.exception(e) + return Response( + {'message': f'Malformed search filters {query_params}'}, + status=status.HTTP_400_BAD_REQUEST) + + return Response(traintuple_list, status=status.HTTP_200_OK) + + def _retrieve(self, pk): + validate_pk(pk) + return get_object_from_ledger(pk, self.ledger_query_call) + + def retrieve(self, request, *args, **kwargs): + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + pk = self.kwargs[lookup_url_kwarg] + + try: + data = self._retrieve(pk) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + except Exception as e: + return Response({'message': str(e)}, status.HTTP_400_BAD_REQUEST) + else: + return Response(data, status=status.HTTP_200_OK) diff --git a/backend/substrapp/views/utils.py b/backend/substrapp/views/utils.py new file mode 100644 index 000000000..77941737d --- /dev/null +++ b/backend/substrapp/views/utils.py @@ -0,0 +1,185 @@ +import os + + +from django.http import FileResponse +from rest_framework.authentication import BasicAuthentication, TokenAuthentication +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response + +from libs.sessionAuthentication import CustomSessionAuthentication +from node.authentication import NodeUser +from substrapp.ledger_utils import get_object_from_ledger, LedgerError +from substrapp.utils import NodeError, get_remote_file, get_owner, get_remote_file_content +from node.models import OutgoingNode + +from django.conf import settings +from rest_framework import status +from requests.auth import HTTPBasicAuth +from wsgiref.util import is_hop_by_hop + +from users.authentication import SecureJWTAuthentication + + +def authenticate_outgoing_request(outgoing_node_id): + try: + outgoing = OutgoingNode.objects.get(node_id=outgoing_node_id) + except OutgoingNode.DoesNotExist: + raise NodeError(f'Unauthorized to call remote node with node_id: {outgoing_node_id}') + + # to authenticate to remote node we use the current node id + # with the associated outgoing secret. 
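+    # The returned credentials are sent as HTTP Basic auth on node-to-node
+    # requests, e.g. when fetching a remote asset in get_remote_asset or when
+    # proxying a file in PermissionMixin.download_file below.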
+ current_node_id = get_owner() + + return HTTPBasicAuth(current_node_id, outgoing.secret) + + +def get_remote_asset(url, node_id, content_hash, salt=None): + auth = authenticate_outgoing_request(node_id) + return get_remote_file_content(url, auth, content_hash, salt=salt) + + +class CustomFileResponse(FileResponse): + def set_headers(self, filelike): + super(CustomFileResponse, self).set_headers(filelike) + + self['Access-Control-Expose-Headers'] = 'Content-Disposition' + + +def node_has_process_permission(asset): + """Check if current node can process input asset.""" + permission = asset['permissions']['process'] + return permission['public'] or get_owner() in permission['authorizedIDs'] + + +class PermissionMixin(object): + authentication_classes = [ + BasicAuthentication, # for node to node + SecureJWTAuthentication, # for user from front/sdk/cli + TokenAuthentication, # for user from front/sdk/cli + CustomSessionAuthentication, # for user on drf web browsable api + ] + permission_classes = [IsAuthenticated] + + def _has_access(self, user, asset): + """Returns true if API consumer can access asset data.""" + if user.is_anonymous: # safeguard, should never happened + return False + + permission = asset['permissions']['process'] + + if isinstance(user, NodeUser): # for node + node_id = user.username + else: # for classic user, test on current msp id + node_id = get_owner() + + return permission['public'] or node_id in permission['authorizedIDs'] + + def download_file(self, request, django_field, ledger_field=None): + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + pk = self.kwargs[lookup_url_kwarg] + + try: + asset = get_object_from_ledger(pk, self.ledger_query_call) + except LedgerError as e: + return Response({'message': str(e.msg)}, status=e.status) + + if not self._has_access(request.user, asset): + return Response({'message': 'Unauthorized'}, + status=status.HTTP_403_FORBIDDEN) + + if get_owner() == asset['owner']: + obj = self.get_object() + data = getattr(obj, django_field) + response = CustomFileResponse( + open(data.path, 'rb'), + as_attachment=True, + filename=os.path.basename(data.path) + ) + else: + node_id = asset['owner'] + auth = authenticate_outgoing_request(node_id) + if not ledger_field: + ledger_field = django_field + r = get_remote_file(asset[ledger_field]['storageAddress'], auth, stream=True) + if not r.ok: + return Response({ + 'message': f'Cannot proxify asset from node {asset["owner"]}: {str(r.text)}' + }, status=r.status_code) + + response = CustomFileResponse( + streaming_content=(chunk for chunk in r.iter_content(512 * 1024)), + status=r.status_code) + + for header in r.headers: + # We don't use hop_by_hop headers since they are incompatible + # with WSGI + if not is_hop_by_hop(header): + response[header] = r.headers.get(header) + + return response + + +def find_primary_key_error(validation_error, key_name='pkhash'): + detail = validation_error.detail + + def find_unique_error(detail_dict): + for key, errors in detail_dict.items(): + if key != key_name: + continue + for error in errors: + if error.code == 'unique': + return error + + return None + + # according to the rest_framework documentation, + # validation_error.detail could be either a dict, a list or a nested + # data structure + + if isinstance(detail, dict): + return find_unique_error(detail) + elif isinstance(detail, list): + for sub_detail in detail: + if isinstance(sub_detail, dict): + unique_error = find_unique_error(sub_detail) + if unique_error is not None: + return unique_error 
+ + return None + + +def validate_pk(pk): + if len(pk) != 64: + raise Exception(f'Wrong pk {pk}') + + try: + int(pk, 16) # test if pk is correct (hexadecimal) + except ValueError: + raise Exception(f'Wrong pk {pk}') + + +def validate_sort(sort): + if sort not in ['asc', 'desc']: + raise Exception(f"Invalid sort value (must be either 'desc' or 'asc'): {sort}") + + +class LedgerException(Exception): + def __init__(self, data, st): + self.data = data + self.st = st + super(LedgerException).__init__() + + +class ValidationException(Exception): + def __init__(self, data, pkhash, st): + self.data = data + self.pkhash = pkhash + self.st = st + super(ValidationException).__init__() + + +def get_success_create_code(): + if getattr(settings, 'LEDGER_SYNC_ENABLED'): + return status.HTTP_201_CREATED + else: + return status.HTTP_202_ACCEPTED diff --git a/backend/users/__init__.py b/backend/users/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/users/apps.py b/backend/users/apps.py new file mode 100644 index 000000000..4ce1fabc0 --- /dev/null +++ b/backend/users/apps.py @@ -0,0 +1,5 @@ +from django.apps import AppConfig + + +class UsersConfig(AppConfig): + name = 'users' diff --git a/backend/users/authentication.py b/backend/users/authentication.py new file mode 100644 index 000000000..16f363d29 --- /dev/null +++ b/backend/users/authentication.py @@ -0,0 +1,28 @@ +from rest_framework_simplejwt.authentication import JWTAuthentication + + +class SecureJWTAuthentication(JWTAuthentication): + + def authenticate(self, request): + if request.resolver_match.url_name in ('user-login', 'api-root'): + return None + + header = self.get_header(request) + if header is None: + return None + + raw_token = self.get_raw_token(header) + if raw_token is None: + return None + + # reconstruct token from httpOnly cookie signature + try: + signature = request.COOKIES['signature'] + except Exception: + return None + else: + raw_token = raw_token + f".{signature}".encode() + + validated_token = self.get_validated_token(raw_token) + + return self.get_user(validated_token), None diff --git a/backend/users/management/commands/__init__.py b/backend/users/management/commands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/users/management/commands/add_user.py b/backend/users/management/commands/add_user.py new file mode 100644 index 000000000..987d267f2 --- /dev/null +++ b/backend/users/management/commands/add_user.py @@ -0,0 +1,36 @@ +import secrets + +from django.contrib.auth import get_user_model +from django.contrib.auth.password_validation import validate_password +from django.core.exceptions import ValidationError +from django.core.management.base import BaseCommand +from django.db import IntegrityError + + +class Command(BaseCommand): + help = 'Add user' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.UserModel = get_user_model() + + def add_arguments(self, parser): + parser.add_argument('username') + parser.add_argument('password', nargs='?', default=secrets.token_hex(8)) + + def handle(self, *args, **options): + + username = options['username'] + password = options['password'] + + try: + validate_password(password, self.UserModel(username=username)) + except ValidationError as err: + self.stderr.write('\n'.join(err.messages)) + else: + try: + self.UserModel.objects.create_user(username=username, password=password) + except IntegrityError as e: + self.stderr.write(f'User already exists: {str(e)}') + else: + 
self.stdout.write(f"password: {password}") diff --git a/backend/users/serializers/__init__.py b/backend/users/serializers/__init__.py new file mode 100644 index 000000000..87ae7d6c1 --- /dev/null +++ b/backend/users/serializers/__init__.py @@ -0,0 +1,5 @@ +# encoding: utf-8 + +from .user import CustomTokenObtainPairSerializer + +__all__ = ['CustomTokenObtainPairSerializer'] diff --git a/backend/users/serializers/user.py b/backend/users/serializers/user.py new file mode 100644 index 000000000..b08c8476d --- /dev/null +++ b/backend/users/serializers/user.py @@ -0,0 +1,32 @@ +from rest_framework_simplejwt.serializers import TokenObtainSerializer +from rest_framework_simplejwt.settings import api_settings +from rest_framework_simplejwt.token_blacklist.models import OutstandingToken +from rest_framework_simplejwt.tokens import AccessToken +from rest_framework_simplejwt.utils import datetime_from_epoch + + +class CustomTokenObtainPairSerializer(TokenObtainSerializer): + def get_token(self, user): + """ + Adds this token to the outstanding token list. + """ + token = AccessToken.for_user(user) + + jti = token[api_settings.JTI_CLAIM] + exp = token['exp'] + + OutstandingToken.objects.create( + user=user, + jti=jti, + token=str(token), + created_at=token.current_time, + expires_at=datetime_from_epoch(exp), + ) + + return token + + def validate(self, attrs): + super().validate(attrs) + token = self.get_token(self.user) + + return token diff --git a/backend/users/urls.py b/backend/users/urls.py new file mode 100644 index 000000000..557c4ef17 --- /dev/null +++ b/backend/users/urls.py @@ -0,0 +1,16 @@ +""" +substrapp URL +""" + +from django.conf.urls import url, include +from rest_framework.routers import DefaultRouter + +# Create a router and register our viewsets with it. +from users.views import UserViewSet + +router = DefaultRouter() +router.register(r'user', UserViewSet, base_name='user') + +urlpatterns = [ + url(r'^', include(router.urls)), +] diff --git a/backend/users/views/__init__.py b/backend/users/views/__init__.py new file mode 100644 index 000000000..f114283ed --- /dev/null +++ b/backend/users/views/__init__.py @@ -0,0 +1,5 @@ +# encoding: utf-8 + +from .user import UserViewSet + +__all__ = ['UserViewSet'] diff --git a/backend/users/views/user.py b/backend/users/views/user.py new file mode 100644 index 000000000..a0faf7517 --- /dev/null +++ b/backend/users/views/user.py @@ -0,0 +1,81 @@ +from django.conf import settings +from django.contrib.auth.models import User + +from rest_framework import status +from rest_framework.permissions import AllowAny +from rest_framework.viewsets import GenericViewSet +from rest_framework.decorators import action +from rest_framework.response import Response +from rest_framework_simplejwt.authentication import AUTH_HEADER_TYPES +from rest_framework_simplejwt.exceptions import TokenError, InvalidToken, AuthenticationFailed + +from users.serializers import CustomTokenObtainPairSerializer + +import tldextract + + +class UserViewSet(GenericViewSet): + queryset = User.objects.all() + serializer_class = CustomTokenObtainPairSerializer + + www_authenticate_realm = 'api' + + permission_classes = [AllowAny] + + def get_authenticate_header(self, request): + return '{0} realm="{1}"'.format( + AUTH_HEADER_TYPES[0], + self.www_authenticate_realm, + ) + + def get_host(self, request): + ext = tldextract.extract(request.get_host()) + host = ext.domain + if ext.suffix: + host += '.' 
+ ext.suffix + + return host + + @action(methods=['post'], detail=False) + def login(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + + try: + serializer.is_valid(raise_exception=True) + except AuthenticationFailed: + return Response({'message': 'wrong username password'}, status=status.HTTP_401_UNAUTHORIZED) + except TokenError as e: + raise InvalidToken(e.args[0]) + + token = serializer.validated_data + + expires = token.current_time + token.lifetime + + tokenString = str(token) + headerPayload = '.'.join(tokenString.split('.')[0:2]) + signature = tokenString.split('.')[2] + + response = Response(token.payload, status=status.HTTP_200_OK) + + host = self.get_host(request) + + if settings.DEBUG: + response.set_cookie('header.payload', value=headerPayload, expires=expires, domain=host) + response.set_cookie('signature', value=signature, httponly=True, domain=host) + else: + response.set_cookie('header.payload', value=headerPayload, expires=expires, secure=True, domain=host) + response.set_cookie('signature', value=signature, httponly=True, secure=True, domain=host) + return response + + @action(detail=False) + def logout(self, request, *args, **kwargs): + response = Response({}, status=status.HTTP_200_OK) + + host = self.get_host(request) + if settings.DEBUG: + response.set_cookie('header.payload', value='', domain=host) + response.set_cookie('signature', value='', httponly=True, domain=host) + else: + response.set_cookie('header.payload', value='', secure=True, domain=host) + response.set_cookie('signature', value='', httponly=True, secure=True, domain=host) + return response diff --git a/bootstrap.sh b/bootstrap.sh deleted file mode 100755 index c622836bc..000000000 --- a/bootstrap.sh +++ /dev/null @@ -1,101 +0,0 @@ -#!/bin/bash - -BASEDIR=$(dirname "$0") - -# if version not passed in, default to latest released version -export VERSION=1.4.1 -# if ca version not passed in, default to latest released version -export CA_VERSION=$VERSION -# current version of thirdparty images (couchdb, kafka and zookeeper) released -export THIRDPARTY_IMAGE_VERSION=0.4.15 -export ARCH=$(echo "$(uname -s|tr '[:upper:]' '[:lower:]'|sed 's/mingw64_nt.*/windows/')-$(uname -m | sed 's/x86_64/amd64/g')") -export MARCH=$(uname -m) - -# starting with 1.2.0, multi-arch images will be default -: ${CA_TAG:="$CA_VERSION"} -: ${FABRIC_TAG:="$VERSION"} - -BINARY_FILE=hyperledger-fabric-${ARCH}-${VERSION}.tar.gz -CA_BINARY_FILE=hyperledger-fabric-ca-${ARCH}-${CA_VERSION}.tar.gz - -# Incrementally downloads the .tar.gz file locally first, only decompressing it -# after the download is complete. This is slower than binaryDownload() but -# allows the download to be resumed. -binaryIncrementalDownload() { - local BINARY_FILE=$1 - local URL=$2 - curl -f -s -C - ${URL} -o ${BINARY_FILE} || rc=$? - # Due to limitations in the current Nexus repo: - # curl returns 33 when there's a resume attempt with no more bytes to download - # curl returns 2 after finishing a resumed download - # with -f curl returns 22 on a 404 - if [ "$rc" = 22 ]; then - # looks like the requested file doesn't actually exist so stop here - return 22 - fi - if [ -z "$rc" ] || [ $rc -eq 33 ] || [ $rc -eq 2 ]; then - # The checksum validates that RC 33 or 2 are not real failures - echo "==> File downloaded. Verifying the md5sum..." - localMd5sum=$(md5sum ${BINARY_FILE} | awk '{print $1}') - remoteMd5sum=$(curl -s ${URL}.md5) - if [ "$localMd5sum" == "$remoteMd5sum" ]; then - echo "==> Extracting ${BINARY_FILE}..." 
- tar xzf ./${BINARY_FILE} --overwrite - echo "==> Done." - rm -f ${BINARY_FILE} ${BINARY_FILE}.md5 - else - echo "Download failed: the local md5sum is different from the remote md5sum. Please try again." - rm -f ${BINARY_FILE} ${BINARY_FILE}.md5 - exit 1 - fi - else - echo "Failure downloading binaries (curl RC=$rc). Please try again and the download will resume from where it stopped." - exit 1 - fi -} - -# This will attempt to download the .tar.gz all at once, but will trigger the -# binaryIncrementalDownload() function upon a failure, allowing for resume -# if there are network failures. -binaryDownload() { - local BINARY_FILE=$1 - local URL=$2 - echo "===> Downloading: " ${URL} - # Check if a previous failure occurred and the file was partially downloaded - if [ -e ${BINARY_FILE} ]; then - echo "==> Partial binary file found. Resuming download..." - binaryIncrementalDownload ${BINARY_FILE} ${URL} - else - curl ${URL} | tar xz || rc=$? - if [ ! -z "$rc" ]; then - echo "==> There was an error downloading the binary file. Switching to incremental download." - echo "==> Downloading file..." - binaryIncrementalDownload ${BINARY_FILE} ${URL} - else - echo "==> Done." - fi - fi -} - -binariesInstall() { - echo "===> Downloading version ${FABRIC_TAG} platform specific fabric binaries" - binaryDownload ${BINARY_FILE} https://nexus.hyperledger.org/content/repositories/releases/org/hyperledger/fabric/hyperledger-fabric/${ARCH}-${VERSION}/${BINARY_FILE} - if [ $? -eq 22 ]; then - echo - echo "------> ${FABRIC_TAG} platform specific fabric binary is not available to download <----" - echo - fi - - echo "===> Downloading version ${CA_TAG} platform specific fabric-ca-client binary" - binaryDownload ${CA_BINARY_FILE} https://nexus.hyperledger.org/content/repositories/releases/org/hyperledger/fabric-ca/hyperledger-fabric-ca/${ARCH}-${CA_VERSION}/${CA_BINARY_FILE} - if [ $? -eq 22 ]; then - echo - echo "------> ${CA_TAG} fabric-ca-client binary is not available to download (Available from 1.1.0-rc1) <----" - echo - fi -} - -binariesInstall - -# remove config directory -rm -r config diff --git a/charts/substra-backend/.gitignore b/charts/substra-backend/.gitignore new file mode 100644 index 000000000..ee3892e87 --- /dev/null +++ b/charts/substra-backend/.gitignore @@ -0,0 +1 @@ +charts/ diff --git a/charts/substra-backend/.helmignore b/charts/substra-backend/.helmignore new file mode 100644 index 000000000..6b8710a71 --- /dev/null +++ b/charts/substra-backend/.helmignore @@ -0,0 +1 @@ +.git diff --git a/charts/substra-backend/Chart.yaml b/charts/substra-backend/Chart.yaml new file mode 100644 index 000000000..e2b51d4eb --- /dev/null +++ b/charts/substra-backend/Chart.yaml @@ -0,0 +1,11 @@ +apiVersion: v1 +name: substra-backend +home: https://substra.org/ +version: 1.0.0-alpha.11 +description: Main package for Substra +icon: https://avatars1.githubusercontent.com/u/38098422?s=200&v=4 +sources: + - https://github.com/SubstraFoudation/substra-backend +maintainers: + - name: ClementGautier + email: clement@gautier.im diff --git a/charts/substra-backend/README.md b/charts/substra-backend/README.md new file mode 100644 index 000000000..374ab50cc --- /dev/null +++ b/charts/substra-backend/README.md @@ -0,0 +1,22 @@ +# Main deployment package of Substra + +## Requirements + +Having a Kubernetes cluster working with Helm initialized. You can do thant locally by installing Minikube and grabbing Helm binary from github. +Then simply launch your cluster using `minikube start` and configure helm with `helm init`. 
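+
+For example, a minimal local setup could look like the following (assuming Minikube and the Helm client are already installed):
+
+```
+minikube start
+helm init
+```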
+ +You will also need the Hyperledger Fabric network setup on the cluster. +Look at the corresponding chart for that (chart-hlf-k8s) + +You will also need a postgresql instance on the cluster, it should already be the case if you install the network first. + +## Install the package +``` +helm install --name hlf-k8s owkin/hlf-k8s +helm install --name substra owkin/substra +``` + +### Cleanup +``` +helm delete --purge hlf-k8s substra +``` diff --git a/charts/substra-backend/requirements.lock b/charts/substra-backend/requirements.lock new file mode 100644 index 000000000..17d707909 --- /dev/null +++ b/charts/substra-backend/requirements.lock @@ -0,0 +1,9 @@ +dependencies: +- name: rabbitmq + repository: https://kubernetes-charts.storage.googleapis.com/ + version: 6.2.6 +- name: postgresql + repository: https://kubernetes-charts.storage.googleapis.com/ + version: 6.2.1 +digest: sha256:f2534148823dfb50af552d37e81afac8e298073730c21bd14d3f471356b2aa09 +generated: "2019-09-05T11:09:00.717234726+02:00" diff --git a/charts/substra-backend/requirements.yaml b/charts/substra-backend/requirements.yaml new file mode 100644 index 000000000..b613c0a93 --- /dev/null +++ b/charts/substra-backend/requirements.yaml @@ -0,0 +1,9 @@ +dependencies: + - name: rabbitmq + repository: https://kubernetes-charts.storage.googleapis.com/ + condition: rabbitmq.enabled + version: ~6.2.5 + - name: postgresql + repository: https://kubernetes-charts.storage.googleapis.com/ + version: ~6.2.0 + condition: postgresql.enabled diff --git a/charts/substra-backend/templates/_helpers.tpl b/charts/substra-backend/templates/_helpers.tpl new file mode 100644 index 000000000..3868c7f06 --- /dev/null +++ b/charts/substra-backend/templates/_helpers.tpl @@ -0,0 +1,16 @@ +{{/* vim: set filetype=mustache: */}} +{{/* +Expand the name of the chart. +*/}} +{{- define "substra.name" -}} +{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" -}} +{{- end -}} + +{{/* +Create a default fully qualified app name. +We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). +*/}} +{{- define "substra.fullname" -}} +{{- $name := default .Chart.Name .Values.nameOverride -}} +{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" -}} +{{- end -}} diff --git a/charts/substra-backend/templates/configmap-backend.yaml b/charts/substra-backend/templates/configmap-backend.yaml new file mode 100644 index 000000000..1fe400cf6 --- /dev/null +++ b/charts/substra-backend/templates/configmap-backend.yaml @@ -0,0 +1,48 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ template "substra.fullname" . 
}}-backend +data: + conf.json: | + { + "name": "{{ .Values.organization.name }}", + "signcert": "/var/hyperledger/msp/signcerts/cert.pem", + "core_peer_mspconfigpath": "/var/hyperledger/msp", + "channel_name": "{{ .Values.channel }}", + "chaincode_name": "{{ .Values.chaincode.name }}", + "chaincode_version": "{{ .Values.chaincode.version }}", + "client": { + "name": "{{ .Values.user.name }}", + "org": "{{ .Values.organization.name }}", + "state_store": "/tmp/hfc-cvs", + "key_path": "/var/hyperledger/msp/keystore/*", + "cert_path": "/var/hyperledger/msp/signcerts/cert.pem", + "msp_id": "{{ .Values.peer.mspID }}" + }, + "peer": { + "name": "peer", + "host": "{{ .Values.peer.host }}", + "port": { + "internal": {{ .Values.peer.port }}, + "external": {{ .Values.peer.port }} + }, + "docker_core_dir": "/var/hyperledger/fabric_cfg", + "tlsCACerts": "/var/hyperledger/admin_msp/cacerts/cacert.pem", + "clientKey": "/var/hyperledger/tls/client/pair/tls.key", + "clientCert": "/var/hyperledger/tls/client/pair/tls.crt", + "grpcOptions": { + "grpc-max-send-message-length": 15, + "grpc.ssl_target_name_override": "{{ .Values.peer.host }}" + } + }, + "orderer": { + "name": "{{ .Values.orderer.name }}", + "host": "{{ .Values.orderer.host }}", + "port": {{ .Values.orderer.port }}, + "ca": "/var/hyperledger/tls/ord/cert/cacert.pem", + "grpcOptions": { + "grpc-max-send-message-length": 15, + "grpc.ssl_target_name_override": "{{ .Values.orderer.host }}" + } + } + } diff --git a/charts/substra-backend/templates/daemonset-nvidia-driver-cos.yaml b/charts/substra-backend/templates/daemonset-nvidia-driver-cos.yaml new file mode 100644 index 000000000..653775a32 --- /dev/null +++ b/charts/substra-backend/templates/daemonset-nvidia-driver-cos.yaml @@ -0,0 +1,99 @@ +{{- if and (.Values.gpu.enabled) (eq .Values.gpu.platform "cos") }} +# https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/nvidia-driver-installer/cos/daemonset-preloaded.yaml +# Copyright 2017 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The Dockerfile and other source for this daemonset are in +# https://github.com/GoogleCloudPlatform/cos-gpu-installer +# +# This is the same as ../../daemonset.yaml except that it assumes that the +# docker image is present on the node instead of downloading from GCR. This +# allows easier upgrades because GKE can preload the correct image on the +# node and the daemonset can just use that image. 
+apiVersion: apps/v1 +kind: DaemonSet +metadata: + name: nvidia-driver-installer + namespace: kube-system + labels: + k8s-app: nvidia-driver-installer +spec: + selector: + matchLabels: + k8s-app: nvidia-driver-installer + updateStrategy: + type: RollingUpdate + template: + metadata: + labels: + name: nvidia-driver-installer + k8s-app: nvidia-driver-installer + spec: + affinity: + nodeAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: cloud.google.com/gke-accelerator + operator: Exists + tolerations: + - operator: "Exists" + hostNetwork: true + hostPID: true + volumes: + - name: dev + hostPath: + path: /dev + - name: vulkan-icd-mount + hostPath: + path: /home/kubernetes/bin/nvidia/vulkan/icd.d + - name: nvidia-install-dir-host + hostPath: + path: /home/kubernetes/bin/nvidia + - name: root-mount + hostPath: + path: / + initContainers: + - image: "cos-nvidia-installer:fixed" + imagePullPolicy: Never + name: nvidia-driver-installer + resources: + requests: + cpu: 0.15 + securityContext: + privileged: true + env: + - name: NVIDIA_INSTALL_DIR_HOST + value: /home/kubernetes/bin/nvidia + - name: NVIDIA_INSTALL_DIR_CONTAINER + value: /usr/local/nvidia + - name: VULKAN_ICD_DIR_HOST + value: /home/kubernetes/bin/nvidia/vulkan/icd.d + - name: VULKAN_ICD_DIR_CONTAINER + value: /etc/vulkan/icd.d + - name: ROOT_MOUNT_DIR + value: /root + volumeMounts: + - name: nvidia-install-dir-host + mountPath: /usr/local/nvidia + - name: vulkan-icd-mount + mountPath: /etc/vulkan/icd.d + - name: dev + mountPath: /dev + - name: root-mount + mountPath: /root + containers: + - image: "gcr.io/google-containers/pause:2.0" + name: pause +{{- end }} diff --git a/charts/substra-backend/templates/daemonset-nvidia-driver-ubuntu.yaml b/charts/substra-backend/templates/daemonset-nvidia-driver-ubuntu.yaml new file mode 100644 index 000000000..6c24efd8b --- /dev/null +++ b/charts/substra-backend/templates/daemonset-nvidia-driver-ubuntu.yaml @@ -0,0 +1,72 @@ +{{- if and (.Values.gpu.enabled) (eq .Values.gpu.platform "ubuntu") }} +# https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/nvidia-driver-installer/ubuntu/daemonset-preloaded.yaml +# Copyright 2017 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+apiVersion: apps/v1 +kind: DaemonSet +metadata: + name: nvidia-driver-installer + namespace: kube-system + labels: + k8s-app: nvidia-driver-installer +spec: + selector: + matchLabels: + k8s-app: nvidia-driver-installer + updateStrategy: + type: RollingUpdate + template: + metadata: + labels: + name: nvidia-driver-installer + k8s-app: nvidia-driver-installer + spec: + affinity: + nodeAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: cloud.google.com/gke-accelerator + operator: Exists + tolerations: + - operator: "Exists" + volumes: + - name: dev + hostPath: + path: /dev + - name: boot + hostPath: + path: /boot + - name: root-mount + hostPath: + path: / + initContainers: + - image: gke-nvidia-installer:fixed + name: nvidia-driver-installer + resources: + requests: + cpu: 0.15 + securityContext: + privileged: true + volumeMounts: + - name: boot + mountPath: /boot + - name: dev + mountPath: /dev + - name: root-mount + mountPath: /root + containers: + - image: "gcr.io/google-containers/pause:2.0" + name: pause +{{- end }} diff --git a/charts/substra-backend/templates/daemonset-nvidia-plugin.yaml b/charts/substra-backend/templates/daemonset-nvidia-plugin.yaml new file mode 100644 index 000000000..cffcb0624 --- /dev/null +++ b/charts/substra-backend/templates/daemonset-nvidia-plugin.yaml @@ -0,0 +1,43 @@ +{{- if .Values.gpu.enabled }} +# https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/1.0.0-beta/nvidia-device-plugin.yml +apiVersion: extensions/v1beta1 +kind: DaemonSet +metadata: + name: nvidia-device-plugin-daemonset + namespace: kube-system +spec: + updateStrategy: + type: RollingUpdate + template: + metadata: + # Mark this pod as a critical add-on; when enabled, the critical add-on scheduler + # reserves resources for critical add-on pods so that they can be rescheduled after + # a failure. This annotation works in tandem with the toleration below. + annotations: + scheduler.alpha.kubernetes.io/critical-pod: "" + labels: + name: nvidia-device-plugin-ds + spec: + tolerations: + # Allow this pod to be rescheduled while the node is in "critical add-ons only" mode. + # This, along with the annotation above marks this pod as a critical add-on. + - key: CriticalAddonsOnly + operator: Exists + - key: nvidia.com/gpu + operator: Exists + effect: NoSchedule + containers: + - image: nvidia/k8s-device-plugin:1.0.0-beta + name: nvidia-device-plugin-ctr + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: ["ALL"] + volumeMounts: + - name: device-plugin + mountPath: /var/lib/kubelet/device-plugins + volumes: + - name: device-plugin + hostPath: + path: /var/lib/kubelet/device-plugins +{{- end }} diff --git a/charts/substra-backend/templates/daemonset-pull-docker-images.yaml b/charts/substra-backend/templates/daemonset-pull-docker-images.yaml new file mode 100644 index 000000000..36fb9cd90 --- /dev/null +++ b/charts/substra-backend/templates/daemonset-pull-docker-images.yaml @@ -0,0 +1,76 @@ +{{- if .Values.docker.pullImages }} +{{- if .Values.docker.config }} +--- +apiVersion: v1 +kind: Secret +type: kubernetes.io/dockerconfigjson +data: + .dockerconfigjson: {{ .Values.docker.config }} +metadata: + name: {{ template "substra.fullname" . }}-pull-docker-images-creds + labels: + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + helm.sh/chart: {{ .Chart.Name }}-{{ .Chart.Version }} + app.kubernetes.io/name: {{ template "substra.name" . 
}}-pull-docker-images-creds + app.kubernetes.io/part-of: {{ template "substra.name" . }} +--- +{{- end }} +apiVersion: extensions/v1beta1 +kind: DaemonSet +metadata: + name: {{ template "substra.fullname" . }}-pull-docker-images + labels: + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + helm.sh/chart: {{ .Chart.Name }}-{{ .Chart.Version }} + app.kubernetes.io/name: {{ template "substra.name" . }}-pull-docker-images + app.kubernetes.io/part-of: {{ template "substra.name" . }} +spec: + selector: + matchLabels: + name: {{ template "substra.fullname" . }}-pull-docker-images + updateStrategy: + type: RollingUpdate + template: + metadata: + labels: + name: {{ template "substra.fullname" . }}-pull-docker-images + spec: + initContainers: + - image: docker + name: init + resources: + requests: + cpu: 0.15 + securityContext: + privileged: true + command: ["sh", "-c"] + args: + - | + {{- range .Values.docker.pullImages }} + docker pull {{ . }} + {{- end }} + volumeMounts: + - name: dockersock + mountPath: "/var/run/docker.sock" + {{- if .Values.docker.config }} + - name: dockerconfig + mountPath: "/root/.docker" + {{- end }} + containers: + - image: "gcr.io/google-containers/pause:2.0" + name: pause + volumes: + - name: dockersock + hostPath: + path: {{ .Values.docker.socket }} + {{- if .Values.docker.config }} + - name: dockerconfig + secret: + secretName: {{ template "substra.fullname" . }}-pull-docker-images-creds + items: + - key: .dockerconfigjson + path: config.json + {{- end }} +{{- end }} diff --git a/charts/substra-backend/templates/deployment-backend.yaml b/charts/substra-backend/templates/deployment-backend.yaml new file mode 100644 index 000000000..eef09920d --- /dev/null +++ b/charts/substra-backend/templates/deployment-backend.yaml @@ -0,0 +1,191 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ template "substra.fullname" . }}-backend + labels: + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + helm.sh/chart: {{ .Chart.Name }}-{{ .Chart.Version }} + app.kubernetes.io/name: {{ template "substra.name" . }}-backend + app.kubernetes.io/part-of: {{ template "substra.name" . }} +spec: + replicas: {{ .Values.backend.replicaCount }} + selector: + matchLabels: + app.kubernetes.io/name: {{ template "substra.name" . }}-backend + app.kubernetes.io/instance: {{ .Release.Name }} + template: + metadata: + labels: + app.kubernetes.io/name: {{ template "substra.name" . }}-backend + app.kubernetes.io/instance: {{ .Release.Name }} + spec: + {{- if or .Values.pullSecretsInline .Values.backend.image.pullSecrets }} + imagePullSecrets: + {{- range $index, $value := .Values.pullSecretsInline }} + - name: {{ template "substra.fullname" $ }}-pull-secret-{{ $index }} + {{- end }} + {{- range .Values.backend.image.pullSecrets }} + - name: {{ . }} + {{- end }} + {{- end }} + containers: + - name: {{ template "substra.name" . 
}} + image: "{{ .Values.backend.image.repository }}:{{ .Values.backend.image.tag }}" + {{- if .Values.backend.image.pullPolicy }} + imagePullPolicy: "{{ .Values.backend.image.pullPolicy }}" + {{- end }} + command: ["/bin/bash"] + {{- if eq .Values.backend.settings "prod" }} + args: ["-c", "python manage.py migrate && python3 manage.py collectstatic --noinput && uwsgi --http :8000 --module backend.wsgi --static-map /static=/usr/src/app/backend/statics --master --processes 4 --threads 2 --need-app --env DJANGO_SETTINGS_MODULE=backend.settings.server.{{ .Values.backend.settings }} "] + {{- else }} + args: ["-c", "python manage.py migrate && DJANGO_SETTINGS_MODULE=backend.settings.server.{{ .Values.backend.settings }} python3 manage.py runserver --noreload 0.0.0.0:8000"] + {{- end }} + env: + - name: ORG + value: {{ .Values.organization.name }} + - name: BACKEND_ORG + value: {{ .Values.organization.name }} + - name: BACKEND_{{ .Values.organization.name | upper }}_DB_NAME + value: {{ .Values.postgresql.postgresqlDatabase }} + - name: BACKEND_DB_USER + value: {{ .Values.postgresql.postgresqlUsername }} + - name: BACKEND_DB_PWD + value: {{ .Values.postgresql.postgresqlPassword }} + - name: DATABASE_HOST + value: {{ .Release.Name }}-postgresql + - name: DJANGO_SETTINGS_MODULE + value: backend.settings.{{ .Values.backend.settings }} + - name: FABRIC_CFG_PATH + value: /var/hyperledger/fabric_cfg + - name: CORE_PEER_ADDRESS_ENV + value: "{{ .Values.peer.host }}:{{ .Values.peer.port }}" + - name: FABRIC_LOGGING_SPEC + value: debug + - name: DEFAULT_DOMAIN + value: "{{ .Values.backend.defaultDomain }}" + - name: CELERY_BROKER_URL + value: "amqp://{{ .Values.rabbitmq.rabbitmq.username }}:{{ .Values.rabbitmq.rabbitmq.password }}@{{ .Release.Name }}-{{ .Values.rabbitmq.host }}:{{ .Values.rabbitmq.port }}//" + - name: BACKEND_DEFAULT_PORT + value: {{ .Values.backend.service.port | quote}} + - name: BACKEND_PEER_PORT + value: "internal" + - name: LEDGER_CONFIG_FILE + value: /conf/{{ .Values.organization.name }}/substra-backend/conf.json + - name: PYTHONUNBUFFERED + value: "1" + - name: MEDIA_ROOT + value: {{ .Values.persistence.hostPath }}/medias/ + ports: + - name: http + containerPort: {{ .Values.backend.service.port }} + protocol: TCP + volumeMounts: + - mountPath: {{ .Values.persistence.hostPath }} + name: data + - mountPath: /conf/{{ .Values.organization.name }}/substra-backend + name: config + readOnly: true + - mountPath: /var/hyperledger/fabric_cfg + name: fabric + readOnly: true + - mountPath: /var/hyperledger/msp/signcerts + name: id-cert + - mountPath: /var/hyperledger/msp/keystore + name: id-key + - mountPath: /var/hyperledger/msp/cacerts + name: cacert + - mountPath: /var/hyperledger/msp/admincerts + name: admin-cert + - mountPath: /var/hyperledger/tls/server/pair + name: tls + - mountPath: /var/hyperledger/tls/server/cert + name: tls-rootcert + - mountPath: /var/hyperledger/tls/client/pair + name: tls-client + - mountPath: /var/hyperledger/tls/client/cert + name: tls-clientrootcert + - mountPath: /var/hyperledger/tls/ord/cert + name: ord-tls-rootcert + - mountPath: /var/hyperledger/admin_msp/signcerts + name: admin-cert + - mountPath: /var/hyperledger/admin_msp/keystore + name: admin-key + - mountPath: /var/hyperledger/admin_msp/cacerts + name: cacert + - mountPath: /var/hyperledger/admin_msp/admincerts + name: admin-cert + livenessProbe: + httpGet: + path: /liveness + port: http + httpHeaders: + - name: Accept + value: "text/html;version=0.0, */*;version=0.0" + initialDelaySeconds: 60 + 
timeoutSeconds: 5 + failureThreshold: 6 + readinessProbe: + httpGet: + path: /readiness + port: http + httpHeaders: + - name: Accept + value: "text/html;version=0.0, */*;version=0.0" + initialDelaySeconds: 10 + timeoutSeconds: 2 + periodSeconds: 5 + resources: + {{- toYaml .Values.backend.resources | nindent 12 }} + volumes: + - name: data + persistentVolumeClaim: + claimName: {{ include "substra.fullname" . }} + - name: config + configMap: + name: {{ include "substra.fullname" . }}-backend + - name: fabric + configMap: + name: {{ $.Values.secrets.fabricConfigmap }} + - name: id-cert + secret: + secretName: {{ $.Values.secrets.cert }} + - name: id-key + secret: + secretName: {{ $.Values.secrets.key }} + - name: cacert + secret: + secretName: {{ $.Values.secrets.caCert }} + - name: tls + secret: + secretName: {{ $.Values.secrets.tls }} + - name: tls-rootcert + secret: + secretName: {{ $.Values.secrets.tlsRootCert }} + - name: tls-client + secret: + secretName: {{ $.Values.secrets.tlsClient }} + - name: tls-clientrootcert + secret: + secretName: {{ $.Values.secrets.tlsClientRootCerts }} + - name: admin-cert + secret: + secretName: {{ $.Values.secrets.adminCert }} + - name: admin-key + secret: + secretName: {{ $.Values.secrets.adminKey }} + - name: ord-tls-rootcert + secret: + secretName: {{ $.Values.secrets.tlsRootCert }} + {{- with .Values.backend.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.backend.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.backend.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} diff --git a/charts/substra-backend/templates/deployment-celerybeat.yaml b/charts/substra-backend/templates/deployment-celerybeat.yaml new file mode 100644 index 000000000..3f5715ac7 --- /dev/null +++ b/charts/substra-backend/templates/deployment-celerybeat.yaml @@ -0,0 +1,60 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ template "substra.fullname" . }}-celerybeat + labels: + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + helm.sh/chart: {{ .Chart.Name }}-{{ .Chart.Version }} + app.kubernetes.io/name: {{ template "substra.name" . }}-celerybeat + app.kubernetes.io/part-of: {{ template "substra.name" . }} +spec: + replicas: {{ .Values.celerybeat.replicaCount }} + selector: + matchLabels: + app.kubernetes.io/name: {{ template "substra.name" . }}-celerybeat + app.kubernetes.io/instance: {{ .Release.Name }} + template: + metadata: + labels: + app.kubernetes.io/name: {{ template "substra.name" . }}-celerybeat + app.kubernetes.io/instance: {{ .Release.Name }} + spec: + {{- if or .Values.pullSecretsInline .Values.backend.image.pullSecrets }} + imagePullSecrets: + {{- range $index, $value := .Values.pullSecretsInline }} + - name: {{ template "substra.fullname" $ }}-pull-secret-{{ $index }} + {{- end }} + {{- range .Values.backend.image.pullSecrets }} + - name: {{ . 
}} + {{- end }} + {{- end }} + containers: + - name: celerybeat + image: "{{ .Values.celerybeat.image.repository }}:{{ .Values.celerybeat.image.tag }}" + {{- if .Values.celerybeat.image.pullPolicy }} + imagePullPolicy: "{{ .Values.celerybeat.image.pullPolicy }}" + {{- end }} + command: ["celery"] + args: ["-A", "backend", "beat", "-l", "debug"] + env: + - name: CELERY_BROKER_URL + value: "amqp://{{ .Values.rabbitmq.rabbitmq.username }}:{{ .Values.rabbitmq.rabbitmq.password }}@{{ .Release.Name }}-{{ .Values.rabbitmq.host }}:{{ .Values.rabbitmq.port }}//" + - name: DJANGO_SETTINGS_MODULE + value: backend.settings.common + - name: PYTHONUNBUFFERED + value: "1" + resources: + {{- toYaml .Values.celerybeat.resources | nindent 12 }} + {{- with .Values.celerybeat.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.celerybeat.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.celerybeat.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} diff --git a/charts/substra-backend/templates/deployment-flower.yaml b/charts/substra-backend/templates/deployment-flower.yaml new file mode 100644 index 000000000..30c10c97c --- /dev/null +++ b/charts/substra-backend/templates/deployment-flower.yaml @@ -0,0 +1,62 @@ +{{- if .Values.flower.enabled -}} +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ template "substra.fullname" . }}-flower + labels: + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + helm.sh/chart: {{ .Chart.Name }}-{{ .Chart.Version }} + app.kubernetes.io/name: {{ template "substra.name" . }}-flower + app.kubernetes.io/part-of: {{ template "substra.name" . }} +spec: + replicas: {{ .Values.flower.replicaCount }} + selector: + matchLabels: + app.kubernetes.io/name: {{ template "substra.name" . }}-flower + app.kubernetes.io/instance: {{ .Release.Name }} + template: + metadata: + labels: + app.kubernetes.io/name: {{ template "substra.name" . }}-flower + app.kubernetes.io/instance: {{ .Release.Name }} + spec: + {{- if or .Values.pullSecretsInline .Values.flower.image.pullSecrets }} + imagePullSecrets: + {{- range $index, $value := .Values.pullSecretsInline }} + - name: {{ template "substra.fullname" $ }}-pull-secret-{{ $index }} + {{- end }} + {{- range .Values.flower.image.pullSecrets }} + - name: {{ . }} + {{- end }} + {{- end }} + containers: + - name: flower + image: "{{ .Values.flower.image.repository }}:{{ .Values.flower.image.tag }}" + {{- if .Values.flower.image.pullPolicy }} + imagePullPolicy: "{{ .Values.flower.image.pullPolicy }}" + {{- end }} + command: ["celery"] + args: ["flower", "-A", "backend"] + env: + - name: CELERY_BROKER_URL + value: "amqp://{{ .Values.rabbitmq.rabbitmq.username }}:{{ .Values.rabbitmq.rabbitmq.password }}@{{ .Release.Name }}-{{ .Values.rabbitmq.host }}:{{ .Values.rabbitmq.port }}//" + - name: DJANGO_SETTINGS_MODULE + value: backend.settings.common + - name: PYTHONUNBUFFERED + value: "1" + resources: + {{- toYaml .Values.flower.resources | nindent 12 }} + {{- with .Values.flower.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.flower.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.flower.tolerations }} + tolerations: + {{- toYaml . 
| nindent 8 }} + {{- end }} +{{- end }} diff --git a/charts/substra-backend/templates/deployment-scheduler.yaml b/charts/substra-backend/templates/deployment-scheduler.yaml new file mode 100644 index 000000000..bb521a160 --- /dev/null +++ b/charts/substra-backend/templates/deployment-scheduler.yaml @@ -0,0 +1,143 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ template "substra.fullname" . }}-scheduler + labels: + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + helm.sh/chart: {{ .Chart.Name }}-{{ .Chart.Version }} + app.kubernetes.io/name: {{ template "substra.name" . }}-scheduler + app.kubernetes.io/part-of: {{ template "substra.name" . }} +spec: + replicas: {{ .Values.celeryworker.replicaCount }} + selector: + matchLabels: + app.kubernetes.io/name: {{ template "substra.name" . }}-scheduler + app.kubernetes.io/instance: {{ .Release.Name }} + template: + metadata: + labels: + app.kubernetes.io/name: {{ template "substra.name" . }}-scheduler + app.kubernetes.io/instance: {{ .Release.Name }} + spec: + {{- if or .Values.pullSecretsInline .Values.backend.image.pullSecrets }} + imagePullSecrets: + {{- range $index, $value := .Values.pullSecretsInline }} + - name: {{ template "substra.fullname" $ }}-pull-secret-{{ $index }} + {{- end }} + {{- range .Values.backend.image.pullSecrets }} + - name: {{ . }} + {{- end }} + {{- end }} + containers: + - name: scheduler + image: "{{ .Values.celeryworker.image.repository }}:{{ .Values.celeryworker.image.tag }}" + {{- if .Values.celeryworker.image.pullPolicy }} + imagePullPolicy: "{{ .Values.celeryworker.image.pullPolicy }}" + {{- end }} + command: ["celery"] + args: ["-A", "backend", "worker", "-l", "info", "-n", "{{ .Values.organization.name }}", "-Q", "{{ .Values.organization.name }},scheduler,celery", "--hostname", "{{ .Values.organization.name }}.scheduler"] + env: + - name: ORG + value: {{ .Values.organization.name }} + - name: BACKEND_ORG + value: {{ .Values.organization.name }} + - name: BACKEND_DEFAULT_PORT + value: "8000" + - name: CELERY_BROKER_URL + value: "amqp://{{ .Values.rabbitmq.rabbitmq.username }}:{{ .Values.rabbitmq.rabbitmq.password }}@{{ .Release.Name }}-{{ .Values.rabbitmq.host }}:{{ .Values.rabbitmq.port }}//" + - name: DJANGO_SETTINGS_MODULE + value: backend.settings.{{ .Values.backend.settings }} + - name: PYTHONUNBUFFERED + value: "1" + - name: DATABASE_HOST + value: {{ .Release.Name }}-postgresql + - name: FABRIC_CFG_PATH_ENV + value: /var/hyperledger/fabric_cfg + - name: CORE_PEER_ADDRESS_ENV + value: "{{ .Values.peer.host }}:{{ .Values.peer.port }}" + - name: FABRIC_LOGGING_SPEC + value: debug + - name: LEDGER_CONFIG_FILE + value: /conf/{{ .Values.organization.name }}/substra-backend/conf.json + volumeMounts: + - mountPath: /conf/{{ .Values.organization.name }}/substra-backend + name: config + readOnly: true + - mountPath: /var/hyperledger/msp/signcerts + name: id-cert + - mountPath: /var/hyperledger/msp/keystore + name: id-key + - mountPath: /var/hyperledger/msp/cacerts + name: cacert + - mountPath: /var/hyperledger/msp/admincerts + name: admin-cert + - mountPath: /var/hyperledger/tls/server/pair + name: tls + - mountPath: /var/hyperledger/tls/server/cert + name: tls-rootcert + - mountPath: /var/hyperledger/tls/client/pair + name: tls-client + - mountPath: /var/hyperledger/tls/client/cert + name: tls-clientrootcert + - mountPath: /var/hyperledger/tls/ord/cert + name: ord-tls-rootcert + - mountPath: /var/hyperledger/admin_msp/signcerts + name: admin-cert 
+ - mountPath: /var/hyperledger/admin_msp/keystore + name: admin-key + - mountPath: /var/hyperledger/admin_msp/cacerts + name: cacert + - mountPath: /var/hyperledger/admin_msp/admincerts + name: admin-cert + resources: + {{- toYaml .Values.celeryworker.resources | nindent 12 }} + volumes: + - name: config + configMap: + name: {{ include "substra.fullname" . }}-backend + - name: fabric + configMap: + name: {{ $.Values.secrets.fabricConfigmap }} + - name: id-cert + secret: + secretName: {{ $.Values.secrets.cert }} + - name: id-key + secret: + secretName: {{ $.Values.secrets.key }} + - name: cacert + secret: + secretName: {{ $.Values.secrets.caCert }} + - name: tls + secret: + secretName: {{ $.Values.secrets.tls }} + - name: tls-rootcert + secret: + secretName: {{ $.Values.secrets.tlsRootCert }} + - name: tls-client + secret: + secretName: {{ $.Values.secrets.tlsClient }} + - name: tls-clientrootcert + secret: + secretName: {{ $.Values.secrets.tlsClientRootCerts }} + - name: admin-cert + secret: + secretName: {{ $.Values.secrets.adminCert }} + - name: admin-key + secret: + secretName: {{ $.Values.secrets.adminKey }} + - name: ord-tls-rootcert + secret: + secretName: {{ $.Values.secrets.tlsRootCert }} + {{- with .Values.celeryworker.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.celeryworker.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.celeryworker.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} diff --git a/charts/substra-backend/templates/deployment-worker.yaml b/charts/substra-backend/templates/deployment-worker.yaml new file mode 100644 index 000000000..3a920dded --- /dev/null +++ b/charts/substra-backend/templates/deployment-worker.yaml @@ -0,0 +1,163 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ template "substra.fullname" . }}-worker + labels: + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + helm.sh/chart: {{ .Chart.Name }}-{{ .Chart.Version }} + app.kubernetes.io/name: {{ template "substra.name" . }}-worker + app.kubernetes.io/part-of: {{ template "substra.name" . }} +spec: + replicas: {{ .Values.replicaCount }} + selector: + matchLabels: + app.kubernetes.io/name: {{ template "substra.name" . }}-worker + app.kubernetes.io/instance: {{ .Release.Name }} + template: + metadata: + labels: + app.kubernetes.io/name: {{ template "substra.name" . }}-worker + app.kubernetes.io/instance: {{ .Release.Name }} + spec: + {{- if or .Values.pullSecretsInline .Values.celeryworker.image.pullSecrets }} + imagePullSecrets: + {{- range $index, $value := .Values.pullSecretsInline }} + - name: {{ template "substra.fullname" $ }}-pull-secret-{{ $index }} + {{- end }} + {{- range .Values.celeryworker.image.pullSecrets }} + - name: {{ . 
}} + {{- end }} + {{- end }} + containers: + - name: worker + image: "{{ .Values.celeryworker.image.repository }}:{{ .Values.celeryworker.image.tag }}" + {{- if .Values.celeryworker.image.pullPolicy }} + imagePullPolicy: "{{ .Values.celeryworker.image.pullPolicy }}" + {{- end }} + command: ["celery"] + args: ["-A", "backend", "worker", "-E", "-l", "info", "-n", "{{ .Values.organization.name }}", "-Q", "{{ .Values.organization.name }},{{ .Values.organization.name }}.worker,celery", "--hostname", "{{ .Values.organization.name }}.worker"] + env: + - name: ORG + value: {{ .Values.organization.name }} + - name: BACKEND_ORG + value: {{ .Values.organization.name }} + - name: BACKEND_DEFAULT_PORT + value: "8000" + - name: CELERY_BROKER_URL + value: "amqp://{{ .Values.rabbitmq.rabbitmq.username }}:{{ .Values.rabbitmq.rabbitmq.password }}@{{ .Release.Name }}-{{ .Values.rabbitmq.host }}:{{ .Values.rabbitmq.port }}//" + - name: DJANGO_SETTINGS_MODULE + value: backend.settings.{{ .Values.backend.settings }} + - name: PYTHONUNBUFFERED + value: "1" + - name: DEFAULT_DOMAIN + value: "{{ .Values.backend.defaultDomain }}" + - name: BACKEND_{{ .Values.organization.name | upper }}_DB_NAME + value: {{ .Values.postgresql.postgresqlDatabase }} + - name: BACKEND_DB_USER + value: {{ .Values.postgresql.postgresqlUsername }} + - name: BACKEND_DB_PWD + value: {{ .Values.postgresql.postgresqlPassword }} + - name: DATABASE_HOST + value: {{ .Release.Name }}-postgresql + - name: FABRIC_CFG_PATH_ENV + value: /var/hyperledger/fabric_cfg + - name: CORE_PEER_ADDRESS_ENV + value: "{{ .Values.peer.host }}:{{ .Values.peer.port }}" + - name: FABRIC_LOGGING_SPEC + value: debug + - name: MEDIA_ROOT + value: {{ .Values.persistence.hostPath }}/medias/ + - name: LEDGER_CONFIG_FILE + value: /conf/{{ .Values.organization.name }}/substra-backend/conf.json + volumeMounts: + - mountPath: /var/run/docker.sock + name: dockersocket + - mountPath: {{ .Values.persistence.hostPath }} + name: data + - mountPath: /conf/{{ .Values.organization.name }}/substra-backend + name: config + readOnly: true + - mountPath: /var/hyperledger/msp/signcerts + name: id-cert + - mountPath: /var/hyperledger/msp/keystore + name: id-key + - mountPath: /var/hyperledger/msp/cacerts + name: cacert + - mountPath: /var/hyperledger/msp/admincerts + name: admin-cert + - mountPath: /var/hyperledger/tls/server/pair + name: tls + - mountPath: /var/hyperledger/tls/server/cert + name: tls-rootcert + - mountPath: /var/hyperledger/tls/client/pair + name: tls-client + - mountPath: /var/hyperledger/tls/client/cert + name: tls-clientrootcert + - mountPath: /var/hyperledger/tls/ord/cert + name: ord-tls-rootcert + - mountPath: /var/hyperledger/admin_msp/signcerts + name: admin-cert + - mountPath: /var/hyperledger/admin_msp/keystore + name: admin-key + - mountPath: /var/hyperledger/admin_msp/cacerts + name: cacert + - mountPath: /var/hyperledger/admin_msp/admincerts + name: admin-cert + resources: + {{- toYaml .Values.celeryworker.resources | nindent 12 }} + volumes: + - name: dockersocket + hostPath: + path: /var/run/docker.sock + - name: data + persistentVolumeClaim: + claimName: {{ include "substra.fullname" . }} + - name: config + configMap: + name: {{ include "substra.fullname" . 
}}-backend + - name: fabric + configMap: + name: {{ $.Values.secrets.fabricConfigmap }} + - name: id-cert + secret: + secretName: {{ .Values.secrets.cert }} + - name: id-key + secret: + secretName: {{ .Values.secrets.key }} + - name: cacert + secret: + secretName: {{ .Values.secrets.caCert }} + - name: tls + secret: + secretName: {{ .Values.secrets.tls }} + - name: tls-rootcert + secret: + secretName: {{ .Values.secrets.tlsRootCert }} + - name: tls-client + secret: + secretName: {{ .Values.secrets.tlsClient }} + - name: tls-clientrootcert + secret: + secretName: {{ .Values.secrets.tlsClientRootCerts }} + - name: admin-cert + secret: + secretName: {{ .Values.secrets.adminCert }} + - name: admin-key + secret: + secretName: {{ .Values.secrets.adminKey }} + - name: ord-tls-rootcert + secret: + secretName: {{ .Values.secrets.tlsRootCert }} + {{- with .Values.celeryworker.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.celeryworker.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.celeryworker.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} diff --git a/charts/substra-backend/templates/ingress-backend.yaml b/charts/substra-backend/templates/ingress-backend.yaml new file mode 100644 index 000000000..e20669183 --- /dev/null +++ b/charts/substra-backend/templates/ingress-backend.yaml @@ -0,0 +1,38 @@ +{{- if .Values.backend.ingress.enabled -}} +apiVersion: extensions/v1beta1 +kind: Ingress +metadata: + name: {{ template "substra.fullname" . }}-backend + labels: + app.kubernetes.io/name: {{ template "substra.fullname" . }}-backend + helm.sh/chart: {{ .Chart.Name }}-{{ .Chart.Version }} + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + {{- with .Values.backend.ingress.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +spec: +{{- if .Values.backend.ingress.tls }} + tls: + {{- range .Values.backend.ingress.tls }} + - hosts: + {{- range .hosts }} + - {{ . | quote }} + {{- end }} + secretName: {{ .secretName }} + {{- end }} +{{- end }} + rules: + {{- range .Values.backend.ingress.hosts }} + - host: {{ .host | quote }} + http: + paths: + {{- range .paths }} + - path: {{ . 
}} + backend: + serviceName: {{ template "substra.fullname" $ }}-backend + servicePort: http + {{- end }} + {{- end }} +{{- end }} diff --git a/charts/substra-backend/templates/job-add-incoming-nodes.yaml b/charts/substra-backend/templates/job-add-incoming-nodes.yaml new file mode 100644 index 000000000..f2494bb34 --- /dev/null +++ b/charts/substra-backend/templates/job-add-incoming-nodes.yaml @@ -0,0 +1,146 @@ +{{- range $index, $value := .Values.incomingNodes }} +--- +apiVersion: batch/v1 +kind: Job +metadata: + name: {{ template "substra.fullname" $ }}-add-incoming-nodes-{{ $index }} + labels: + app.kubernetes.io/managed-by: {{ $.Release.Service }} + app.kubernetes.io/instance: {{ $.Release.Name }} + helm.sh/chart: {{ $.Chart.Name }}-{{ $.Chart.Version }} + app.kubernetes.io/name: {{ template "substra.name" $ }}-add-incoming-nodes-{{ $index }} + app.kubernetes.io/part-of: {{ template "substra.name" $ }} +spec: + template: + spec: + restartPolicy: OnFailure + {{- if or $.Values.pullSecretsInline $.Values.backend.image.pullSecrets }} + imagePullSecrets: + {{- range $index, $value := $.Values.pullSecretsInline }} + - name: {{ template "substra.fullname" $ }}-pull-secret-{{ $index }} + {{- end }} + {{- range $.Values.backend.image.pullSecrets }} + - name: {{ . }} + {{- end }} + {{- end }} + containers: + - name: substra-backend + image: "{{ $.Values.backend.image.repository }}:{{ $.Values.backend.image.tag }}" + imagePullPolicy: "{{ $.Values.backend.image.pullPolicy }}" + command: ["python3"] + args: ["manage.py", "create_incoming_node", {{ .name }}, {{ .secret }}] + env: + - name: ORG + value: {{ $.Values.organization.name }} + - name: BACKEND_ORG + value: {{ $.Values.organization.name }} + - name: BACKEND_{{ $.Values.organization.name | upper }}_DB_NAME + value: {{ $.Values.postgresql.postgresqlDatabase }} + - name: BACKEND_DB_USER + value: {{ $.Values.postgresql.postgresqlUsername }} + - name: BACKEND_DB_PWD + value: {{ $.Values.postgresql.postgresqlPassword }} + - name: DATABASE_HOST + value: {{ $.Release.Name }}-postgresql + - name: DJANGO_SETTINGS_MODULE + value: backend.settings.{{ $.Values.backend.settings }} + - name: FABRIC_CFG_PATH + value: /var/hyperledger/fabric_cfg + - name: CORE_PEER_ADDRESS_ENV + value: "{{ $.Values.peer.host }}:{{ $.Values.peer.port }}" + - name: FABRIC_LOGGING_SPEC + value: debug + - name: DEFAULT_DOMAIN + value: "{{ $.Values.backend.defaultDomain }}" + - name: CELERY_BROKER_URL + value: "amqp://{{ $.Values.rabbitmq.rabbitmq.username }}:{{ $.Values.rabbitmq.rabbitmq.password }}@{{ $.Release.Name }}-{{ $.Values.rabbitmq.host }}:{{ $.Values.rabbitmq.port }}//" + - name: BACK_AUTH_USER + value: {{ $.user | quote }} + - name: BACK_AUTH_PASSWORD + value: {{ $.password | quote }} + - name: BACKEND_DEFAULT_PORT + value: {{ $.Values.backend.service.port | quote}} + - name: BACKEND_PEER_PORT + value: "internal" + - name: LEDGER_CONFIG_FILE + value: /conf/{{ $.Values.organization.name }}/substra-backend/conf.json + - name: PYTHONUNBUFFERED + value: "1" + - name: MEDIA_ROOT + value: {{ $.Values.persistence.hostPath }}/medias/ + volumeMounts: + - mountPath: {{ $.Values.persistence.hostPath }} + name: data + - mountPath: /conf/{{ $.Values.organization.name }}/substra-backend + name: config + readOnly: true + - mountPath: /var/hyperledger/fabric_cfg + name: fabric + readOnly: true + - mountPath: /var/hyperledger/msp/signcerts + name: id-cert + - mountPath: /var/hyperledger/msp/keystore + name: id-key + - mountPath: /var/hyperledger/msp/cacerts + name: cacert + - 
mountPath: /var/hyperledger/msp/admincerts + name: admin-cert + - mountPath: /var/hyperledger/tls/server/pair + name: tls + - mountPath: /var/hyperledger/tls/server/cert + name: tls-rootcert + - mountPath: /var/hyperledger/tls/client/pair + name: tls-client + - mountPath: /var/hyperledger/tls/client/cert + name: tls-clientrootcert + - mountPath: /var/hyperledger/tls/ord/cert + name: ord-tls-rootcert + - mountPath: /var/hyperledger/admin_msp/signcerts + name: admin-cert + - mountPath: /var/hyperledger/admin_msp/keystore + name: admin-key + - mountPath: /var/hyperledger/admin_msp/cacerts + name: cacert + - mountPath: /var/hyperledger/admin_msp/admincerts + name: admin-cert + volumes: + - name: data + persistentVolumeClaim: + claimName: {{ include "substra.fullname" $ }} + - name: config + configMap: + name: {{ include "substra.fullname" $ }}-backend + - name: fabric + configMap: + name: {{ $.Values.secrets.fabricConfigmap }} + - name: id-cert + secret: + secretName: {{ $.Values.secrets.cert }} + - name: id-key + secret: + secretName: {{ $.Values.secrets.key }} + - name: cacert + secret: + secretName: {{ $.Values.secrets.caCert }} + - name: tls + secret: + secretName: {{ $.Values.secrets.tls }} + - name: tls-rootcert + secret: + secretName: {{ $.Values.secrets.tlsRootCert }} + - name: tls-client + secret: + secretName: {{ $.Values.secrets.tlsClient }} + - name: tls-clientrootcert + secret: + secretName: {{ $.Values.secrets.tlsClientRootCerts }} + - name: admin-cert + secret: + secretName: {{ $.Values.secrets.adminCert }} + - name: admin-key + secret: + secretName: {{ $.Values.secrets.adminKey }} + - name: ord-tls-rootcert + secret: + secretName: {{ $.Values.secrets.tlsRootCert }} +{{- end }} diff --git a/charts/substra-backend/templates/job-add-outgoing-nodes.yaml b/charts/substra-backend/templates/job-add-outgoing-nodes.yaml new file mode 100644 index 000000000..dbf606e8f --- /dev/null +++ b/charts/substra-backend/templates/job-add-outgoing-nodes.yaml @@ -0,0 +1,146 @@ +{{- range $index, $value := .Values.outgoingNodes }} +--- +apiVersion: batch/v1 +kind: Job +metadata: + name: {{ template "substra.fullname" $ }}-add-outgoing-nodes-{{ $index }} + labels: + app.kubernetes.io/managed-by: {{ $.Release.Service }} + app.kubernetes.io/instance: {{ $.Release.Name }} + helm.sh/chart: {{ $.Chart.Name }}-{{ $.Chart.Version }} + app.kubernetes.io/name: {{ template "substra.name" $ }}-add-outgoing-nodes-{{ $index }} + app.kubernetes.io/part-of: {{ template "substra.name" $ }} +spec: + template: + spec: + restartPolicy: OnFailure + {{- if or $.Values.pullSecretsInline $.Values.backend.image.pullSecrets }} + imagePullSecrets: + {{- range $index, $value := $.Values.pullSecretsInline }} + - name: {{ template "substra.fullname" $ }}-pull-secret-{{ $index }} + {{- end }} + {{- range $.Values.backend.image.pullSecrets }} + - name: {{ . 
}} + {{- end }} + {{- end }} + containers: + - name: substra-backend + image: "{{ $.Values.backend.image.repository }}:{{ $.Values.backend.image.tag }}" + imagePullPolicy: "{{ $.Values.backend.image.pullPolicy }}" + command: ["python3"] + args: ["manage.py", "create_outgoing_node", {{ .name }}, {{ .secret }}] + env: + - name: ORG + value: {{ $.Values.organization.name }} + - name: BACKEND_ORG + value: {{ $.Values.organization.name }} + - name: BACKEND_{{ $.Values.organization.name | upper }}_DB_NAME + value: {{ $.Values.postgresql.postgresqlDatabase }} + - name: BACKEND_DB_USER + value: {{ $.Values.postgresql.postgresqlUsername }} + - name: BACKEND_DB_PWD + value: {{ $.Values.postgresql.postgresqlPassword }} + - name: DATABASE_HOST + value: {{ $.Release.Name }}-postgresql + - name: DJANGO_SETTINGS_MODULE + value: backend.settings.{{ $.Values.backend.settings }} + - name: FABRIC_CFG_PATH + value: /var/hyperledger/fabric_cfg + - name: CORE_PEER_ADDRESS_ENV + value: "{{ $.Values.peer.host }}:{{ $.Values.peer.port }}" + - name: FABRIC_LOGGING_SPEC + value: debug + - name: DEFAULT_DOMAIN + value: "{{ $.Values.backend.defaultDomain }}" + - name: CELERY_BROKER_URL + value: "amqp://{{ $.Values.rabbitmq.rabbitmq.username }}:{{ $.Values.rabbitmq.rabbitmq.password }}@{{ $.Release.Name }}-{{ $.Values.rabbitmq.host }}:{{ $.Values.rabbitmq.port }}//" + - name: BACK_AUTH_USER + value: {{ $.user | quote }} + - name: BACK_AUTH_PASSWORD + value: {{ $.password | quote }} + - name: BACKEND_DEFAULT_PORT + value: {{ $.Values.backend.service.port | quote}} + - name: BACKEND_PEER_PORT + value: "internal" + - name: LEDGER_CONFIG_FILE + value: /conf/{{ $.Values.organization.name }}/substra-backend/conf.json + - name: PYTHONUNBUFFERED + value: "1" + - name: MEDIA_ROOT + value: {{ $.Values.persistence.hostPath }}/medias/ + volumeMounts: + - mountPath: {{ $.Values.persistence.hostPath }} + name: data + - mountPath: /conf/{{ $.Values.organization.name }}/substra-backend + name: config + readOnly: true + - mountPath: /var/hyperledger/fabric_cfg + name: fabric + readOnly: true + - mountPath: /var/hyperledger/msp/signcerts + name: id-cert + - mountPath: /var/hyperledger/msp/keystore + name: id-key + - mountPath: /var/hyperledger/msp/cacerts + name: cacert + - mountPath: /var/hyperledger/msp/admincerts + name: admin-cert + - mountPath: /var/hyperledger/tls/server/pair + name: tls + - mountPath: /var/hyperledger/tls/server/cert + name: tls-rootcert + - mountPath: /var/hyperledger/tls/client/pair + name: tls-client + - mountPath: /var/hyperledger/tls/client/cert + name: tls-clientrootcert + - mountPath: /var/hyperledger/tls/ord/cert + name: ord-tls-rootcert + - mountPath: /var/hyperledger/admin_msp/signcerts + name: admin-cert + - mountPath: /var/hyperledger/admin_msp/keystore + name: admin-key + - mountPath: /var/hyperledger/admin_msp/cacerts + name: cacert + - mountPath: /var/hyperledger/admin_msp/admincerts + name: admin-cert + volumes: + - name: data + persistentVolumeClaim: + claimName: {{ include "substra.fullname" $ }} + - name: config + configMap: + name: {{ include "substra.fullname" $ }}-backend + - name: fabric + configMap: + name: {{ $.Values.secrets.fabricConfigmap }} + - name: id-cert + secret: + secretName: {{ $.Values.secrets.cert }} + - name: id-key + secret: + secretName: {{ $.Values.secrets.key }} + - name: cacert + secret: + secretName: {{ $.Values.secrets.caCert }} + - name: tls + secret: + secretName: {{ $.Values.secrets.tls }} + - name: tls-rootcert + secret: + secretName: {{ 
$.Values.secrets.tlsRootCert }} + - name: tls-client + secret: + secretName: {{ $.Values.secrets.tlsClient }} + - name: tls-clientrootcert + secret: + secretName: {{ $.Values.secrets.tlsClientRootCerts }} + - name: admin-cert + secret: + secretName: {{ $.Values.secrets.adminCert }} + - name: admin-key + secret: + secretName: {{ $.Values.secrets.adminKey }} + - name: ord-tls-rootcert + secret: + secretName: {{ $.Values.secrets.tlsRootCert }} +{{- end }} diff --git a/charts/substra-backend/templates/job-add-users.yaml b/charts/substra-backend/templates/job-add-users.yaml new file mode 100644 index 000000000..319d1b2b2 --- /dev/null +++ b/charts/substra-backend/templates/job-add-users.yaml @@ -0,0 +1,143 @@ +{{- range $index, $value := .Values.users }} +--- +apiVersion: batch/v1 +kind: Job +metadata: + name: {{ template "substra.fullname" $ }}-add-users-{{ $index }} + labels: + app.kubernetes.io/managed-by: {{ $.Release.Service }} + app.kubernetes.io/instance: {{ $.Release.Name }} + helm.sh/chart: {{ $.Chart.Name }}-{{ $.Chart.Version }} + app.kubernetes.io/name: {{ template "substra.name" $ }}-add-users-{{ $index }} + app.kubernetes.io/part-of: {{ template "substra.name" $ }} +spec: + template: + spec: + restartPolicy: OnFailure + {{- with $.Values.backend.image.pullSecrets }} + imagePullSecrets: + {{- range . }} + - name: {{ . }} + {{- end }} + {{- end }} + containers: + - name: backend + image: "{{ $.Values.backend.image.repository }}:{{ $.Values.backend.image.tag }}" + imagePullPolicy: "{{ $.Values.backend.image.pullPolicy }}" + command: ["python3"] + args: ["manage.py", "add_user", {{ .name }}, {{ .secret }}] + env: + - name: ORG + value: {{ $.Values.organization.name }} + - name: BACKEND_ORG + value: {{ $.Values.organization.name }} + - name: BACKEND_{{ $.Values.organization.name | upper }}_DB_NAME + value: {{ $.Values.postgresql.postgresqlDatabase }} + - name: BACKEND_DB_USER + value: {{ $.Values.postgresql.postgresqlUsername }} + - name: BACKEND_DB_PWD + value: {{ $.Values.postgresql.postgresqlPassword }} + - name: DATABASE_HOST + value: {{ $.Release.Name }}-postgresql + - name: DJANGO_SETTINGS_MODULE + value: backend.settings.{{ $.Values.backend.settings }} + - name: FABRIC_CFG_PATH + value: /var/hyperledger/fabric_cfg + - name: CORE_PEER_ADDRESS_ENV + value: "{{ $.Values.peer.host }}:{{ $.Values.peer.port }}" + - name: FABRIC_LOGGING_SPEC + value: debug + - name: DEFAULT_DOMAIN + value: "{{ $.Values.backend.defaultDomain }}" + - name: CELERY_BROKER_URL + value: "amqp://{{ $.Values.rabbitmq.rabbitmq.username }}:{{ $.Values.rabbitmq.rabbitmq.password }}@{{ $.Release.Name }}-{{ $.Values.rabbitmq.host }}:{{ $.Values.rabbitmq.port }}//" + - name: BACK_AUTH_USER + value: {{ $.user | quote }} + - name: BACK_AUTH_PASSWORD + value: {{ $.password | quote }} + - name: BACKEND_DEFAULT_PORT + value: {{ $.Values.backend.service.port | quote}} + - name: BACKEND_PEER_PORT + value: "internal" + - name: LEDGER_CONFIG_FILE + value: /conf/{{ $.Values.organization.name }}/backend/conf.json + - name: PYTHONUNBUFFERED + value: "1" + - name: MEDIA_ROOT + value: {{ $.Values.persistence.hostPath }}/medias/ + volumeMounts: + - mountPath: {{ $.Values.persistence.hostPath }} + name: data + - mountPath: /conf/{{ $.Values.organization.name }}/backend + name: config + readOnly: true + - mountPath: /var/hyperledger/fabric_cfg + name: fabric + readOnly: true + - mountPath: /var/hyperledger/msp/signcerts + name: id-cert + - mountPath: /var/hyperledger/msp/keystore + name: id-key + - mountPath: 
/var/hyperledger/msp/cacerts + name: cacert + - mountPath: /var/hyperledger/msp/admincerts + name: admin-cert + - mountPath: /var/hyperledger/tls/server/pair + name: tls + - mountPath: /var/hyperledger/tls/server/cert + name: tls-rootcert + - mountPath: /var/hyperledger/tls/client/pair + name: tls-client + - mountPath: /var/hyperledger/tls/client/cert + name: tls-clientrootcert + - mountPath: /var/hyperledger/tls/ord/cert + name: ord-tls-rootcert + - mountPath: /var/hyperledger/admin_msp/signcerts + name: admin-cert + - mountPath: /var/hyperledger/admin_msp/keystore + name: admin-key + - mountPath: /var/hyperledger/admin_msp/cacerts + name: cacert + - mountPath: /var/hyperledger/admin_msp/admincerts + name: admin-cert + volumes: + - name: data + persistentVolumeClaim: + claimName: {{ include "substra.fullname" $ }} + - name: config + configMap: + name: {{ include "substra.fullname" $ }}-backend + - name: fabric + configMap: + name: {{ $.Values.secrets.fabricConfigmap }} + - name: id-cert + secret: + secretName: {{ $.Values.secrets.cert }} + - name: id-key + secret: + secretName: {{ $.Values.secrets.key }} + - name: cacert + secret: + secretName: {{ $.Values.secrets.caCert }} + - name: tls + secret: + secretName: {{ $.Values.secrets.tls }} + - name: tls-rootcert + secret: + secretName: {{ $.Values.secrets.tlsRootCert }} + - name: tls-client + secret: + secretName: {{ $.Values.secrets.tlsClient }} + - name: tls-clientrootcert + secret: + secretName: {{ $.Values.secrets.tlsClientRootCerts }} + - name: admin-cert + secret: + secretName: {{ $.Values.secrets.adminCert }} + - name: admin-key + secret: + secretName: {{ $.Values.secrets.adminKey }} + - name: ord-tls-rootcert + secret: + secretName: {{ $.Values.secrets.tlsRootCert }} +{{- end }} diff --git a/charts/substra-backend/templates/secret-pull.yaml b/charts/substra-backend/templates/secret-pull.yaml new file mode 100644 index 000000000..93e364044 --- /dev/null +++ b/charts/substra-backend/templates/secret-pull.yaml @@ -0,0 +1,16 @@ +{{- range $index, $value := .Values.pullSecretsInline }} +--- +apiVersion: v1 +data: + .dockerconfigjson: {{ $value }} +kind: Secret +metadata: + name: {{ template "substra.fullname" $ }}-pull-secret-{{ $index }} + labels: + app.kubernetes.io/managed-by: {{ $.Release.Service }} + app.kubernetes.io/instance: {{ $.Release.Name }} + helm.sh/chart: {{ $.Chart.Name }}-{{ $.Chart.Version }} + app.kubernetes.io/name: {{ template "substra.fullname" $ }}-pull-secret-{{ $index }} + app.kubernetes.io/part-of: {{ template "substra.name" $ }} +type: kubernetes.io/dockerconfigjson +{{- end }} diff --git a/charts/substra-backend/templates/service-backend.yaml b/charts/substra-backend/templates/service-backend.yaml new file mode 100644 index 000000000..1445ea544 --- /dev/null +++ b/charts/substra-backend/templates/service-backend.yaml @@ -0,0 +1,50 @@ +apiVersion: v1 +kind: Service +metadata: + name: {{ template "substra.fullname" . }}-backend + labels: + app.kubernetes.io/name: {{ template "substra.name" . }}-backend + helm.sh/chart: {{ .Chart.Name }}-{{ .Chart.Version }} + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + app.kubernetes.io/part-of: {{ template "substra.name" . 
}} + {{- if .Values.backend.service.labels }} + {{- toYaml .Values.backend.service.labels | nindent 4 }} + {{- end }} + {{- if .Values.backend.service.annotations }} + annotations: + {{- toYaml .Values.backend.service.annotations | nindent 4 }} + {{- end }} +spec: +{{- if (or (eq .Values.backend.service.type "ClusterIP") (empty .Values.backend.service.type)) }} + type: ClusterIP + {{- if .Values.backend.service.clusterIP }} + clusterIP: {{ .Values.backend.service.clusterIP }} + {{end}} +{{- else if eq .Values.backend.service.type "LoadBalancer" }} + type: {{ .Values.backend.service.type }} + {{- if .Values.backend.service.loadBalancerIP }} + loadBalancerIP: {{ .Values.backend.service.loadBalancerIP }} + {{- end }} + {{- if .Values.backend.service.loadBalancerSourceRanges }} + loadBalancerSourceRanges: +{{ toYaml .Values.backend.service.loadBalancerSourceRanges | indent 4 }} + {{- end -}} +{{- else }} + type: {{ .Values.backend.service.type }} +{{- end }} +{{- if .Values.backend.service.externalIPs }} + externalIPs: +{{ toYaml .Values.backend.service.externalIPs | indent 4 }} +{{- end }} + ports: + - name: http + port: {{ .Values.backend.service.port }} + protocol: TCP + targetPort: 8000 +{{ if (and (eq .Values.backend.service.type "NodePort") (not (empty .Values.backend.service.nodePort))) }} + nodePort: {{.Values.backend.service.nodePort}} +{{ end }} + selector: + app.kubernetes.io/name: {{ template "substra.name" . }}-backend + app.kubernetes.io/instance: {{ .Release.Name }} diff --git a/charts/substra-backend/templates/service-flower.yaml b/charts/substra-backend/templates/service-flower.yaml new file mode 100644 index 000000000..2ac3a107f --- /dev/null +++ b/charts/substra-backend/templates/service-flower.yaml @@ -0,0 +1,52 @@ +{{- if .Values.flower.enabled -}} +apiVersion: v1 +kind: Service +metadata: + name: {{ template "substra.fullname" . }}-flower + labels: + app.kubernetes.io/name: {{ template "substra.name" . }}-flower + helm.sh/chart: {{ .Chart.Name }}-{{ .Chart.Version }} + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/instance: {{ .Release.Name }} + app.kubernetes.io/part-of: {{ template "substra.name" . 
}} + {{- if .Values.flower.service.labels }} + {{- toYaml .Values.flower.service.labels | nindent 4 }} + {{- end }} + {{- if .Values.flower.service.annotations }} + annotations: + {{- toYaml .Values.flower.service.annotations | nindent 4 }} + {{- end }} +spec: +{{- if (or (eq .Values.flower.service.type "ClusterIP") (empty .Values.flower.service.type)) }} + type: ClusterIP + {{- if .Values.flower.service.clusterIP }} + clusterIP: {{ .Values.flower.service.clusterIP }} + {{end}} +{{- else if eq .Values.flower.service.type "LoadBalancer" }} + type: {{ .Values.flower.service.type }} + {{- if .Values.flower.service.loadBalancerIP }} + loadBalancerIP: {{ .Values.flower.service.loadBalancerIP }} + {{- end }} + {{- if .Values.flower.service.loadBalancerSourceRanges }} + loadBalancerSourceRanges: +{{ toYaml .Values.flower.service.loadBalancerSourceRanges | indent 4 }} + {{- end -}} +{{- else }} + type: {{ .Values.flower.service.type }} +{{- end }} +{{- if .Values.flower.service.externalIPs }} + externalIPs: +{{ toYaml .Values.flower.service.externalIPs | indent 4 }} +{{- end }} + ports: + - name: http + port: {{ .Values.flower.service.port }} + protocol: TCP + targetPort: 5555 +{{ if (and (eq .Values.flower.service.type "NodePort") (not (empty .Values.flower.service.nodePort))) }} + nodePort: {{.Values.flower.service.nodePort}} +{{ end }} + selector: + app.kubernetes.io/name: {{ template "substra.name" . }}-flower + app.kubernetes.io/instance: {{ .Release.Name }} +{{- end }} diff --git a/charts/substra-backend/templates/storage.yaml b/charts/substra-backend/templates/storage.yaml new file mode 100644 index 000000000..213671680 --- /dev/null +++ b/charts/substra-backend/templates/storage.yaml @@ -0,0 +1,27 @@ +--- +kind: PersistentVolumeClaim +apiVersion: v1 +metadata: + name: {{ template "substra.fullname" . }} +spec: + storageClassName: "" + accessModes: + - ReadWriteMany + resources: + requests: + storage: {{ .Values.persistence.size | quote }} +--- +apiVersion: v1 +kind: PersistentVolume +metadata: + name: {{ template "substra.fullname" . 
}} +spec: + storageClassName: "" + persistentVolumeReclaimPolicy: Recycle + capacity: + storage: {{ .Values.persistence.size | quote }} + accessModes: + - ReadWriteMany + hostPath: + path: {{ .Values.persistence.hostPath | quote }} + type: DirectoryOrCreate diff --git a/charts/substra-backend/values.yaml b/charts/substra-backend/values.yaml new file mode 100644 index 000000000..eb1667e23 --- /dev/null +++ b/charts/substra-backend/values.yaml @@ -0,0 +1,222 @@ +gpu: + enabled: false + platform: ubuntu # or cos + +docker: + # Path of the docker socket on the host + socket: /var/run/docker.sock + # Dockerconfig to be used to pull the images (base64'd) + config: null + # Images to pull + pullImages: [] + # - substrafoundation/substra-tools:0.0.1 + +# Inline secrets used to pull images of pods (base64'd) +pullSecretsInline: [] + +backend: + replicaCount: 1 + settings: prod + siteHost: localhost + defaultDomain: localhost + + image: + repository: substrafoundation/substra-backend + tag: latest + pullPolicy: IfNotPresent + pullSecrets: [] + + service: + type: NodePort + port: 8000 + annotations: {} + labels: {} + clusterIP: "" + externalIPs: [] + loadBalancerIP: "" + loadBalancerSourceRanges: [] + # nodePort: 30000 + + ingress: + enabled: false + annotations: {} + # kubernetes.io/ingress.class: nginx + # kubernetes.io/tls-acme: "true" + hosts: + - host: chart-example.local + paths: [] + + tls: [] + # - secretName: chart-example-tls + # hosts: + # - chart-example.local + + resources: {} + # We usually recommend not to specify default resources and to leave this as a conscious + # choice for the user. This also increases chances charts run on environments with little + # resources, such as Minikube. If you do want to specify resources, uncomment the following + # lines, adjust them as necessary, and remove the curly braces after 'resources:'. 
+ # limits: + # cpu: 100m + # memory: 128Mi + # requests: + # cpu: 100m + # memory: 128Mi + +outgoingNodes: [] + # - name: nodeId + # secret: nodeSecret +incomingNodes: [] + # - name: nodeId + # secret: nodeSecret + +users: [] +# - name: username +# secret: password + +persistence: + hostPath: "/substra" + size: "10Gi" + +# Secrets names +secrets: + # Certificate, saved under key 'cert.pem' + cert: hlf-idcert + # Key, saved under 'key.pem' + key: hlf-idkey + # CA Cert, saved under 'cacert.pem' + caCert: hlf-cacert + # TLS secret, saved under keys 'tls.crt' and 'tls.key' (to conform with K8S nomenclature) + tls: hlf-tls + # TLS root CA certificate saved under key 'cert.pem' + tlsRootCert: hlf-tlsrootcert + # TLS client root CA certificates saved under any names (as there may be multiple) + tlsClient: hlf-tls + # TLS client root CA certificates saved under any names (as there may be multiple) + tlsClientRootCerts: hlf-client-tlsrootcert + # This should contain the Certificate of the Peer Organisation admin + # This is necessary to successfully run the peer + adminCert: hlf-admincert + # This should contain the Private Key of the Peer Organisation admin + # This is necessary to successfully join a channel + adminKey: hlf-adminkey + # This should include the Orderer TLS 'cacert.pem' + ordTlsRootCert: hlf-client-tlsrootcert + # This will include the organization config json file (peer only) + orgConfig: org-config + # This will include the organization config json file (peer only) + fabricConfigmap: network-hlf-k8s-fabric + +organization: + name: substra + +user: + name: user + +orderer: + host: orderer-hlf-ord + port: 7050 + +peer: + host: healthchain-peer.owkin.com + port: 443 + mspID: OwkinPeerMSP + +channel: mychannel + +chaincode: + name: mycc + version: "1.0" + +postgresql: + enabled: true + postgresqlDatabase: substra + postgresqlUsername: postgres + postgresqlPassword: postgres + persistence: + enabled: false + +rabbitmq: + enabled: true + rabbitmq: + username: rabbitmq + password: rabbitmq + host: rabbitmq + port: 5672 + persistence: + enabled: false + +flower: + enabled: true + host: flower + port: 5555 + persistence: + enabled: false + + image: + repository: substrafoundation/flower + tag: latest + pullPolicy: IfNotPresent + pullSecrets: [] + + service: + type: NodePort + port: 5555 + annotations: {} + labels: {} + clusterIP: "" + externalIPs: [] + loadBalancerIP: "" + loadBalancerSourceRanges: [] + +celerybeat: + replicaCount: 1 + image: + repository: substrafoundation/celerybeat + tag: latest + pullPolicy: IfNotPresent + pullSecrets: [] + + resources: {} + # We usually recommend not to specify default resources and to leave this as a conscious + # choice for the user. This also increases chances charts run on environments with little + # resources, such as Minikube. If you do want to specify resources, uncomment the following + # lines, adjust them as necessary, and remove the curly braces after 'resources:'. + # limits: + # cpu: 100m + # memory: 128Mi + # requests: + # cpu: 100m + # memory: 128Mi + + nodeSelector: {} + + tolerations: [] + + affinity: {} + +celeryworker: + replicaCount: 1 + image: + repository: substrafoundation/celeryworker + tag: latest + pullPolicy: IfNotPresent + pullSecrets: [] + + resources: {} + # We usually recommend not to specify default resources and to leave this as a conscious + # choice for the user. This also increases chances charts run on environments with little + # resources, such as Minikube. 
If you do want to specify resources, uncomment the following + # lines, adjust them as necessary, and remove the curly braces after 'resources:'. + # limits: + # cpu: 100m + # memory: 128Mi + # requests: + # cpu: 100m + # memory: 128Mi + + nodeSelector: {} + + tolerations: [] + + affinity: {} diff --git a/docker/README.md b/docker/README.md deleted file mode 100644 index a39478f98..000000000 --- a/docker/README.md +++ /dev/null @@ -1,10 +0,0 @@ -First build all the images with the `build-docker-images.sh` in the root directory of this repository - - -Use classical docker-compose command in the root directory of this repository with `-f` and ` --project-directory` options. - -For instance, `up -d`: - -``` docker-compose -f docker/docker-compose.yaml --project-directory . up -d ``` - -To test from scratch, you may have to remove the `/substra/backup/postgres-data/` directory in the root directory of this repository. diff --git a/docker/celerybeat/Dockerfile b/docker/celerybeat/Dockerfile index 144cdad4f..6e5768419 100644 --- a/docker/celerybeat/Dockerfile +++ b/docker/celerybeat/Dockerfile @@ -7,10 +7,12 @@ RUN apt-get install -y git curl netcat RUN mkdir -p /usr/src/app WORKDIR /usr/src/app -COPY ./substrabac/requirements.txt /usr/src/app/. +COPY ./backend/requirements.txt /usr/src/app/. RUN pip3 install -r requirements.txt -COPY ./substrabac/libs /usr/src/app/libs -COPY ./substrabac/substrapp /usr/src/app/substrapp -COPY ./substrabac/substrabac /usr/src/app/substrabac +COPY ./backend/libs /usr/src/app/libs +COPY ./backend/substrapp /usr/src/app/substrapp +COPY ./backend/backend /usr/src/app/backend +COPY ./backend/node /usr/src/app/node +COPY ./backend/users /usr/src/app/users diff --git a/docker/celeryworker/Dockerfile b/docker/celeryworker/Dockerfile index 6bf39448d..939d2d69f 100644 --- a/docker/celeryworker/Dockerfile +++ b/docker/celeryworker/Dockerfile @@ -7,16 +7,12 @@ RUN apt-get install -y git curl netcat RUN mkdir -p /usr/src/app WORKDIR /usr/src/app -COPY ./bootstrap.sh /usr/src/ -RUN cd ../; sh bootstrap.sh; cd app - -COPY ./substrabac/requirements.txt /usr/src/app/. +COPY ./backend/requirements.txt /usr/src/app/. RUN pip3 install -r requirements.txt -COPY ./substrabac/libs /usr/src/app/libs -COPY ./substrabac/base_metrics /usr/src/app/base_metrics -COPY ./substrabac/fake_metrics /usr/src/app/fake_metrics -COPY ./substrabac/fake_data_sample /usr/src/app/fake_data_sample -COPY ./substrabac/substrapp /usr/src/app/substrapp -COPY ./substrabac/substrabac /usr/src/app/substrabac +COPY ./backend/libs /usr/src/app/libs +COPY ./backend/substrapp /usr/src/app/substrapp +COPY ./backend/backend /usr/src/app/backend +COPY ./backend/node /usr/src/app/node +COPY ./backend/users /usr/src/app/users diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml deleted file mode 100644 index c37b59666..000000000 --- a/docker/docker-compose.yaml +++ /dev/null @@ -1,150 +0,0 @@ -version: '2.3' - -networks: - default: - external: - name: net_substra - -services: - - substrabacowkin: - container_name: owkin.substrabac - hostname: substrabacowkin - image: substra/substrabac - command: /bin/bash -c "while ! 
{ nc -z postgresql 5432 2>&1; }; do sleep 1; done; python manage.py migrate --settings=substrabac.settings.dev; python3 manage.py runserver 0.0.0.0:8000" - volumes: - - /substra:/substra - - /substra/data/orgs/owkin/user/msp:/opt/gopath/src/github.com/hyperledger/fabric/peer/msp - links: - - postgresql - - rabbit - ports: - - "8000:8000" - depends_on: - - postgresql - - rabbit - environment: - - DATABASE_HOST=postgresql - - DJANGO_SETTINGS_MODULE=substrabac.settings.dev - - PYTHONUNBUFFERED=1 - - FABRIC_CFG_PATH=/substra/conf/owkin/peer1/ - - substrabacchunantes: - container_name: chunantes.substrabac - hostname: substrabacchunantes - image: substra/substrabac - command: /bin/bash -c "while ! { nc -z postgresql 5432 2>&1; }; do sleep 1; done; python manage.py migrate --settings=substrabac.settings.dev; python3 manage.py runserver 0.0.0.0:8001" - volumes: - - /substra:/substra - - /substra/data/orgs/chu-nantes/user/msp:/opt/gopath/src/github.com/hyperledger/fabric/peer/msp - links: - - postgresql - - rabbit - ports: - - "8001:8001" - depends_on: - - postgresql - - rabbit - environment: - - DATABASE_HOST=postgresql - - DJANGO_SETTINGS_MODULE=substrabac.settings.dev - - PYTHONUNBUFFERED=1 - - FABRIC_CFG_PATH=/substra/conf/chu-nantes/peer1/ - - # Celery worker - - celerybeat: - container_name: celerybeat - hostname: celerybeat - image: substra/celerybeat - command: /bin/bash -c "while ! { nc -z rabbit 5672 2>&1; }; do sleep 1; done; celery -A substrabac beat -l info -b rabbit" - volumes: - - /substra:/substra - links: - - rabbit - depends_on: - - rabbit - environment: - - PYTHONUNBUFFERED=1 - - DJANGO_SETTINGS_MODULE=substrabac.settings.common - - worker_owkin: - container_name: worker_owkin - hostname: worker_owkin - # runtime: nvidia - image: substra/celeryworker - command: /bin/bash -c "while ! { nc -z rabbit 5672 2>&1; }; do sleep 1; done; celery -A substrabac worker -l info -n owkin -Q owkin,celery -b rabbit" - volumes: - - /var/run/docker.sock:/var/run/docker.sock - - /substra:/substra - - /substra/data/orgs/owkin/user/msp:/opt/gopath/src/github.com/hyperledger/fabric/peer/msp - links: - - rabbit - - substrabacowkin - depends_on: - - rabbit - environment: - - ORG=owkin - - PYTHONUNBUFFERED=1 - - DATABASE_HOST=postgresql - - DJANGO_SETTINGS_MODULE=substrabac.settings.dev - - FABRIC_CFG_PATH=/substra/conf/owkin/peer1/ - - worker_chunantes: - container_name: worker_chunantes - hostname: worker_chunantes - # runtime: nvidia - image: substra/celeryworker - command: /bin/bash -c "while ! 
{ nc -z rabbit 5672 2>&1; }; do sleep 1; done; celery -A substrabac worker -l info -n chunantes -Q chu-nantes,celery -b rabbit" - volumes: - - /var/run/docker.sock:/var/run/docker.sock - - /substra:/substra - - /substra/data/orgs/chu-nantes/user/msp:/opt/gopath/src/github.com/hyperledger/fabric/peer/msp - links: - - rabbit - depends_on: - - rabbit - - substrabacchunantes - environment: - - ORG=chu-nantes - - PYTHONUNBUFFERED=1 - - DATABASE_HOST=postgresql - - DJANGO_SETTINGS_MODULE=substrabac.settings.dev - - FABRIC_CFG_PATH=/substra/conf/chu-nantes/peer1/ - - - rabbit: - container_name: rabbit - hostname: rabbit - image: rabbitmq:3 - environment: - - RABBITMQ_DEFAULT_USER=guest - - RABBITMQ_DEFAULT_PASS=guest - - HOSTNAME=rabbitmq - - RABBITMQ_NODENAME=rabbitmq -# ports: -# - "5672:5672" - - postgresql: - container_name: postgresql - hostname: postgresql - image: substra/postgresql - volumes: - - /substra/backup/postgres-data:/var/lib/postgresql/data -# ports: -# - "5432:5432" - environment: - - POSTGRES_USER=postgres - - USER=postgres - - POSTGRES_PASSWORD=postgrespwd - - POSTGRES_DB=substrabac - -# flower: -# container_name: flower -# image: mher/flower -# ports: -# - "5555:5555" -# depends_on: -# - rabbit -# environment: -# - CELERY_BROKER_URL=amqp://rabbit:5672 diff --git a/docker/flower/Dockerfile b/docker/flower/Dockerfile new file mode 100644 index 000000000..c4020b7af --- /dev/null +++ b/docker/flower/Dockerfile @@ -0,0 +1,19 @@ +FROM python:3.6 + +RUN apt-get update +RUN apt-get install -y python3 python3-pip python3-dev build-essential libssl-dev libffi-dev libxml2-dev libxslt1-dev zlib1g-dev +RUN apt-get install -y git curl netcat + +RUN mkdir -p /usr/src/app +WORKDIR /usr/src/app + +COPY ./backend/requirements.txt /usr/src/app/. + +RUN pip3 install -r requirements.txt +RUN pip3 install flower + +COPY ./backend/libs /usr/src/app/libs +COPY ./backend/substrapp /usr/src/app/substrapp +COPY ./backend/backend /usr/src/app/backend +COPY ./backend/node /usr/src/app/node +COPY ./backend/users /usr/src/app/users diff --git a/docker/postgresql/Dockerfile b/docker/postgresql/Dockerfile new file mode 100644 index 000000000..8935ec280 --- /dev/null +++ b/docker/postgresql/Dockerfile @@ -0,0 +1,3 @@ +FROM library/postgres:10.5 + +COPY docker/postgresql/init.sh /docker-entrypoint-initdb.d/init.sh diff --git a/docker/postgresql/init.sh b/docker/postgresql/init.sh index 92590b661..bde4e6a69 100644 --- a/docker/postgresql/init.sh +++ b/docker/postgresql/init.sh @@ -1,11 +1,11 @@ #!/bin/bash -createdb -U ${USER} -E UTF8 substrabac_owkin -psql -U ${USER} -d substrabac_owkin -c "GRANT ALL PRIVILEGES ON DATABASE substrabac_owkin to substrabac;ALTER ROLE substrabac WITH SUPERUSER CREATEROLE CREATEDB;" +createdb -U ${USER} -E UTF8 backend_owkin +psql -U ${USER} -d backend_owkin -c "GRANT ALL PRIVILEGES ON DATABASE backend_owkin to backend;ALTER ROLE backend WITH SUPERUSER CREATEROLE CREATEDB;" -createdb -U ${USER} -E UTF8 substrabac_chunantes -psql -U ${USER} -d substrabac_chunantes -c "GRANT ALL PRIVILEGES ON DATABASE substrabac_chunantes to substrabac;ALTER ROLE substrabac WITH SUPERUSER CREATEROLE CREATEDB;" +createdb -U ${USER} -E UTF8 backend_chunantes +psql -U ${USER} -d backend_chunantes -c "GRANT ALL PRIVILEGES ON DATABASE backend_chunantes to backend;ALTER ROLE backend WITH SUPERUSER CREATEROLE CREATEDB;" -createdb -U ${USER} -E UTF8 substrabac_clb -psql -U ${USER} -d substrabac_clb -c "GRANT ALL PRIVILEGES ON DATABASE substrabac_clb to substrabac;ALTER ROLE substrabac WITH SUPERUSER 
CREATEROLE CREATEDB;" +createdb -U ${USER} -E UTF8 backend_clb +psql -U ${USER} -d backend_clb -c "GRANT ALL PRIVILEGES ON DATABASE backend_clb to backend;ALTER ROLE backend WITH SUPERUSER CREATEROLE CREATEDB;" diff --git a/docker/start.py b/docker/start.py index 998381faa..a96a2b883 100644 --- a/docker/start.py +++ b/docker/start.py @@ -6,7 +6,7 @@ from subprocess import call, check_output dir_path = os.path.dirname(os.path.realpath(__file__)) -raven_dryrunner_url = "https://a1c2de65bb0f4120aa11d75bca9b47f6@sentry.io/1402760" +raven_backend_url = "https://cff352ba26fc49f19e01692db93bf951@sentry.io/1317743" raven_worker_url = "https://76abd6b5d11e48ea8a118831c86fc615@sentry.io/1402762" raven_scheduler_url = raven_worker_url @@ -19,14 +19,31 @@ 'clb': 8002 } +BACKEND_CREDENTIALS = { + 'owkin': { + 'username': 'substra', + 'password': 'p@$swr0d44' + }, + 'chunantes': { + 'username': 'substra', + 'password': 'p@$swr0d45' + }, + 'clb': { + 'username': 'substra', + 'password': 'p@$swr0d46' + } +} + +SUBSTRA_FOLDER = os.getenv('SUBSTRA_PATH', '/substra') + def generate_docker_compose_file(conf, launch_settings): # POSTGRES - POSTGRES_USER = 'substrabac' - USER = 'substrabac' - POSTGRES_PASSWORD = 'substrabac' - POSTGRES_DB = 'substrabac' + POSTGRES_USER = 'backend' + USER = 'backend' + POSTGRES_PASSWORD = 'backend' + POSTGRES_DB = 'backend' # RABBITMQ RABBITMQ_DEFAULT_USER = 'guest' @@ -38,196 +55,212 @@ def generate_docker_compose_file(conf, launch_settings): # CELERY CELERY_BROKER_URL = f'amqp://{RABBITMQ_DEFAULT_USER}:{RABBITMQ_DEFAULT_PASS}@{RABBITMQ_DOMAIN}:{RABBITMQ_PORT}//' - try: from ruamel import yaml except ImportError: import yaml + wait_rabbit = f'while ! {{ nc -z {RABBITMQ_DOMAIN} {RABBITMQ_PORT} 2>&1; }}; do sleep 1; done' + wait_psql = 'while ! { nc -z postgresql 5432 2>&1; }; do sleep 1; done' + # Docker compose config - docker_compose = {'substrabac_services': {}, - 'substrabac_tools': {'postgresql': {'container_name': 'postgresql', - 'image': 'library/postgres:10.5', - 'restart': 'unless-stopped', - 'logging': {'driver': 'json-file', 'options': {'max-size': '20m', 'max-file': '5'}}, - 'environment': [f'POSTGRES_USER={POSTGRES_USER}', - f'USER={USER}', - f'POSTGRES_PASSWORD={POSTGRES_PASSWORD}', - f'POSTGRES_DB={POSTGRES_DB}'], - 'volumes': [ - '/substra/backup/postgres-data:/var/lib/postgresql/data', - f'{dir_path}/postgresql/init.sh:/docker-entrypoint-initdb.d/init.sh'], - }, - 'celerybeat': {'container_name': 'celerybeat', - 'hostname': 'celerybeat', - 'image': 'substra/celerybeat', - 'restart': 'unless-stopped', - 'command': '/bin/bash -c "while ! { nc -z rabbit 5672 2>&1; }; do sleep 1; done; while ! 
{ nc -z postgresql 5432 2>&1; }; do sleep 1; done; celery -A substrabac beat -l info"', - 'logging': {'driver': 'json-file', 'options': {'max-size': '20m', 'max-file': '5'}}, - 'environment': ['PYTHONUNBUFFERED=1', - f'CELERY_BROKER_URL={CELERY_BROKER_URL}', - f'DJANGO_SETTINGS_MODULE=substrabac.settings.common'], - 'depends_on': ['postgresql', 'rabbit'] - }, - 'rabbit': {'container_name': 'rabbit', - 'hostname': 'rabbitmq', # Must be set to be able to recover from volume - 'restart': 'unless-stopped', - 'image': 'rabbitmq:3', - 'logging': {'driver': 'json-file', 'options': {'max-size': '20m', 'max-file': '5'}}, - 'environment': [f'RABBITMQ_DEFAULT_USER={RABBITMQ_DEFAULT_USER}', - f'RABBITMQ_DEFAULT_PASS={RABBITMQ_DEFAULT_PASS}', - f'HOSTNAME={RABBITMQ_HOSTNAME}', - f'RABBITMQ_NODENAME={RABBITMQ_NODENAME}'], - 'volumes': ['/substra/backup/rabbit-data:/var/lib/rabbitmq'] - }, - }, - 'path': os.path.join(dir_path, './docker-compose-dynamic.yaml')} + docker_compose = { + 'backend_services': {}, + 'backend_tools': { + 'postgresql': { + 'container_name': 'postgresql', + 'labels': ['substra'], + 'image': 'substra/postgresql', + 'restart': 'unless-stopped', + 'logging': {'driver': 'json-file', 'options': {'max-size': '20m', 'max-file': '5'}}, + 'environment': [ + f'POSTGRES_USER={POSTGRES_USER}', + f'USER={USER}', + f'POSTGRES_PASSWORD={POSTGRES_PASSWORD}', + f'POSTGRES_DB={POSTGRES_DB}'], + 'volumes': [ + f'{SUBSTRA_FOLDER}/backup/postgres-data:/var/lib/postgresql/data'], + }, + 'celerybeat': { + 'container_name': 'celerybeat', + 'labels': ['substra'], + 'hostname': 'celerybeat', + 'image': 'substra/celerybeat', + 'restart': 'unless-stopped', + 'command': f'/bin/bash -c "{wait_rabbit}; {wait_psql}; ' + 'celery -A backend beat -l info"', + 'logging': {'driver': 'json-file', 'options': {'max-size': '20m', 'max-file': '5'}}, + 'environment': [ + 'PYTHONUNBUFFERED=1', + f'CELERY_BROKER_URL={CELERY_BROKER_URL}', + f'DJANGO_SETTINGS_MODULE=backend.settings.common'], + 'depends_on': ['postgresql', 'rabbit'] + }, + 'rabbit': { + 'container_name': 'rabbit', + 'labels': ['substra'], + 'hostname': 'rabbitmq', # Must be set to be able to recover from volume + 'restart': 'unless-stopped', + 'image': 'rabbitmq:3-management', + 'logging': {'driver': 'json-file', 'options': {'max-size': '20m', 'max-file': '5'}}, + 'environment': [ + f'RABBITMQ_DEFAULT_USER={RABBITMQ_DEFAULT_USER}', + f'RABBITMQ_DEFAULT_PASS={RABBITMQ_DEFAULT_PASS}', + f'HOSTNAME={RABBITMQ_HOSTNAME}', + f'RABBITMQ_NODENAME={RABBITMQ_NODENAME}'], + 'volumes': [f'{SUBSTRA_FOLDER}/backup/rabbit-data:/var/lib/rabbitmq'] + }, + 'flower': { + 'container_name': f'flower', + 'labels': ['substra'], + 'hostname': f'flower', + 'ports': ['5555:5555'], + 'image': 'substra/flower', + 'restart': 'unless-stopped', + 'command': 'celery flower -A backend', + 'logging': {'driver': 'json-file', 'options': {'max-size': '20m', 'max-file': '5'}}, + 'environment': [f'CELERY_BROKER_URL={CELERY_BROKER_URL}', + 'DJANGO_SETTINGS_MODULE=backend.settings.common'], + 'depends_on': ['rabbit', 'postgresql'] + } + }, + 'path': os.path.join(dir_path, './docker-compose-dynamic.yaml')} for org in conf: org_name = org['name'] - orderer_ca = org['orderer']['ca'] - peer = org['peer']['name'] - tls_peer_dir = f'/substra/data/orgs/{org_name}/tls/{peer}' - org_name_stripped = org_name.replace('-', '') port = BACKEND_PORT[org_name_stripped] + credentials = BACKEND_CREDENTIALS[org_name_stripped] cpu_count = os.cpu_count() processes = 2 * int(cpu_count) + 1 if launch_settings == 'prod': - 
django_server = f'python3 manage.py collectstatic --noinput; uwsgi --http :{port} --module substrabac.wsgi --static-map /static=/usr/src/app/substrabac/statics --master --processes {processes} --threads 2' + django_server = f'python3 manage.py collectstatic --noinput; uwsgi --http :{port} '\ + f'--module backend.wsgi --static-map /static=/usr/src/app/backend/statics ' \ + f'--master --processes {processes} --threads 2 --need-app ' \ + f'--env DJANGO_SETTINGS_MODULE=backend.settings.server.prod' else: - - django_server = f'python3 manage.py runserver 0.0.0.0:{port}' + django_server = f'DJANGO_SETTINGS_MODULE=backend.settings.server.dev ' \ + f'python3 manage.py runserver --noreload 0.0.0.0:{port}' backend_global_env = [ f'ORG={org_name_stripped}', - f'SUBSTRABAC_ORG={org_name}', - f'SUBSTRABAC_DEFAULT_PORT={port}', - 'SUBSTRABAC_PEER_PORT=internal', + f'BACKEND_ORG={org_name}', + f'BACKEND_DEFAULT_PORT={port}', + 'BACKEND_PEER_PORT=internal', + + f'LEDGER_CONFIG_FILE={SUBSTRA_FOLDER}/conf/{org_name}/substra-backend/conf.json', 'PYTHONUNBUFFERED=1', 'DATABASE_HOST=postgresql', + f"TASK_CAPTURE_LOGS=True", + f"TASK_CLEAN_EXECUTION_ENVIRONMENT=True", + f"TASK_CACHE_DOCKER_IMAGES=False", + f'CELERY_BROKER_URL={CELERY_BROKER_URL}', - f'DJANGO_SETTINGS_MODULE=substrabac.settings.{launch_settings}', - - # Basic auth - f"BACK_AUTH_USER={os.environ.get('BACK_AUTH_USER', '')}", - f"BACK_AUTH_PASSWORD={os.environ.get('BACK_AUTH_PASSWORD', '')}", - f"SITE_HOST={os.environ.get('SITE_HOST', 'localhost')}", - f"SITE_PORT={os.environ.get('BACK_PORT', 9000)}", - - # HLF overwrite config from core.yaml - f"FABRIC_CFG_PATH_ENV={org['peer']['docker_core_dir']}", - f"FABRIC_LOGGING_SPEC={FABRIC_LOGGING_SPEC}", - f"CORE_PEER_ADDRESS_ENV={org['peer']['host']}:{org['peer']['port']['internal']}", - f"CORE_PEER_MSPCONFIGPATH={org['core_peer_mspconfigpath']}", - f"CORE_PEER_TLS_CERT_FILE={tls_peer_dir}/server/server.crt", - f"CORE_PEER_TLS_KEY_FILE={tls_peer_dir}/server/server.key", - f"CORE_PEER_TLS_ROOTCERT_FILE={tls_peer_dir}/server/server.pem", - f"CORE_PEER_TLS_CLIENTCERT_FILE={tls_peer_dir}/client/client.crt", - f"CORE_PEER_TLS_CLIENTKEY_FILE={tls_peer_dir}/client/client.key", - f"CORE_PEER_TLS_CLIENTROOTCAS_FILES={tls_peer_dir}/client/client.pem", + f'DJANGO_SETTINGS_MODULE=backend.settings.{launch_settings}', ] hlf_volumes = [ - # config (core.yaml + substrabac/conf.json) - f'/substra/conf/{org_name}:/substra/conf/{org_name}:ro', + # config (core.yaml + substra-backend/conf.json) + f'{SUBSTRA_FOLDER}/conf/{org_name}:{SUBSTRA_FOLDER}/conf/{org_name}:ro', # HLF files - f'{orderer_ca}:{orderer_ca}:ro', - f'{tls_peer_dir}:{tls_peer_dir}:ro', f'{org["core_peer_mspconfigpath"]}:{org["core_peer_mspconfigpath"]}:ro', ] - backend = {'container_name': f'{org_name_stripped}.substrabac', - 'image': 'substra/substrabac', - 'restart': 'unless-stopped', - 'ports': [f'{port}:{port}'], - 'command': f'/bin/bash -c "while !
{{ nc -z postgresql 5432 2>&1; }}; do sleep 1; done; yes | python manage.py migrate; {django_server}"', - 'logging': {'driver': 'json-file', 'options': {'max-size': '20m', 'max-file': '5'}}, - 'environment': backend_global_env.copy(), - 'volumes': ['/substra/medias:/substra/medias', - '/substra/dryrun:/substra/dryrun', - '/substra/servermedias:/substra/servermedias', - '/substra/static:/usr/src/app/substrabac/statics'] + hlf_volumes, - 'depends_on': ['postgresql', 'rabbit']} - - scheduler = {'container_name': f'{org_name_stripped}.scheduler', - 'hostname': f'{org_name}.scheduler', - 'image': 'substra/celeryworker', - 'restart': 'unless-stopped', - 'command': f'/bin/bash -c "while ! {{ nc -z rabbit 5672 2>&1; }}; do sleep 1; done; while ! {{ nc -z postgresql 5432 2>&1; }}; do sleep 1; done; celery -A substrabac worker -l info -n {org_name_stripped} -Q {org_name},scheduler,celery --hostname {org_name}.scheduler"', - 'logging': {'driver': 'json-file', 'options': {'max-size': '20m', 'max-file': '5'}}, - 'environment': backend_global_env.copy(), - 'volumes': hlf_volumes, - 'depends_on': [f'substrabac{org_name_stripped}', 'postgresql', 'rabbit']} - - worker = {'container_name': f'{org_name_stripped}.worker', - 'hostname': f'{org_name}.worker', - 'image': 'substra/celeryworker', - 'restart': 'unless-stopped', - 'command': f'/bin/bash -c "while ! {{ nc -z rabbit 5672 2>&1; }}; do sleep 1; done; while ! {{ nc -z postgresql 5432 2>&1; }}; do sleep 1; done; celery -A substrabac worker -l info -n {org_name_stripped} -Q {org_name},{org_name}.worker,celery --hostname {org_name}.worker"', - 'logging': {'driver': 'json-file', 'options': {'max-size': '20m', 'max-file': '5'}}, - 'environment': backend_global_env.copy(), - 'volumes': ['/var/run/docker.sock:/var/run/docker.sock', - '/substra/medias:/substra/medias', - '/substra/servermedias:/substra/servermedias'] + hlf_volumes, - 'depends_on': [f'substrabac{org_name_stripped}', 'rabbit']} - - dryrunner = {'container_name': f'{org_name_stripped}.dryrunner', - 'hostname': f'{org_name}.dryrunner', - 'image': 'substra/celeryworker', - 'restart': 'unless-stopped', - 'command': f'/bin/bash -c "while ! {{ nc -z rabbit 5672 2>&1; }}; do sleep 1; done; while ! 
{{ nc -z postgresql 5432 2>&1; }}; do sleep 1; done; celery -A substrabac worker -l info -n {org_name_stripped} -Q {org_name},{org_name}.dryrunner,celery --hostname {org_name}.dryrunner"', - 'logging': {'driver': 'json-file', 'options': {'max-size': '20m', 'max-file': '5'}}, - 'environment': backend_global_env.copy(), - 'volumes': ['/var/run/docker.sock:/var/run/docker.sock', - '/substra/dryrun:/substra/dryrun', - '/substra/medias:/substra/medias', - '/substra/servermedias:/substra/servermedias'] + hlf_volumes, - 'depends_on': [f'substrabac{org_name_stripped}', 'rabbit']} + # HLF files + for tls_key in ['tlsCACerts', 'clientCert', 'clientKey']: + hlf_volumes.append(f'{org["peer"][tls_key]}:{org["peer"][tls_key]}:ro') + + # load incoming/outgoing node fixtures/ that should not be executed in production env + fixtures_command = '' + user_command = '' + if launch_settings == 'dev': + fixtures_command = f"python manage.py init_nodes ./node/nodes/{org_name}MSP.json" + # $ replace is needed for docker-compose $ special variable + user_command = f"python manage.py add_user {credentials['username']} "\ + f"'{credentials['password'].replace('$', '$$')}'" + + backend = { + 'container_name': f'substra-backend.{org_name_stripped}.xyz', + 'labels': ['substra'], + 'image': 'substra/substra-backend', + 'restart': 'unless-stopped', + 'ports': [f'{port}:{port}'], + 'command': f'/bin/bash -c "{wait_rabbit}; {wait_psql}; ' + f'yes | python manage.py migrate; {fixtures_command}; {user_command}; {django_server}"', + 'logging': {'driver': 'json-file', 'options': {'max-size': '20m', 'max-file': '5'}}, + 'environment': backend_global_env.copy(), + 'volumes': [ + f'{SUBSTRA_FOLDER}/medias:{SUBSTRA_FOLDER}/medias:rw', + f'{SUBSTRA_FOLDER}/servermedias:{SUBSTRA_FOLDER}/servermedias:ro', + f'{SUBSTRA_FOLDER}/static:/usr/src/app/backend/statics'] + hlf_volumes, + 'depends_on': ['postgresql', 'rabbit']} + + scheduler = { + 'container_name': f'{org_name_stripped}.scheduler', + 'labels': ['substra'], + 'hostname': f'{org_name}.scheduler', + 'image': 'substra/celeryworker', + 'restart': 'unless-stopped', + 'command': f'/bin/bash -c "{wait_rabbit}; {wait_psql}; ' + f'celery -A backend worker -l info -n {org_name_stripped} ' + f'-Q {org_name},scheduler,celery --hostname {org_name}.scheduler"', + 'logging': {'driver': 'json-file', 'options': {'max-size': '20m', 'max-file': '5'}}, + 'environment': backend_global_env.copy(), + 'volumes': hlf_volumes, + 'depends_on': [f'backend{org_name_stripped}', 'postgresql', 'rabbit']} + + worker = { + 'container_name': f'{org_name_stripped}.worker', + 'labels': ['substra'], + 'hostname': f'{org_name}.worker', + 'image': 'substra/celeryworker', + 'restart': 'unless-stopped', + 'command': f'/bin/bash -c "{wait_rabbit}; {wait_psql}; ' + f'celery -A backend worker -l info -n {org_name_stripped} ' + f'-Q {org_name},{org_name}.worker,celery --hostname {org_name}.worker"', + 'logging': {'driver': 'json-file', 'options': {'max-size': '20m', 'max-file': '5'}}, + 'environment': backend_global_env.copy(), + 'volumes': [ + '/var/run/docker.sock:/var/run/docker.sock', + f'{SUBSTRA_FOLDER}/medias:{SUBSTRA_FOLDER}/medias:rw', + f'{SUBSTRA_FOLDER}/servermedias:{SUBSTRA_FOLDER}/servermedias:ro'] + hlf_volumes, + 'depends_on': [f'backend{org_name_stripped}', 'rabbit']} # Check if we have nvidia docker if 'nvidia' in check_output(['docker', 'system', 'info', '-f', '"{{.Runtimes}}"']).decode('utf-8'): worker['runtime'] = 'nvidia' if launch_settings == 'dev': - media_root = 
f'MEDIA_ROOT=/substra/medias/{org_name_stripped}' - dryrun_root = f'DRYRUN_ROOT=/substra/dryrun/{org_name}' - + media_root = f'MEDIA_ROOT={SUBSTRA_FOLDER}/medias/{org_name_stripped}' worker['environment'].append(media_root) - dryrunner['environment'].append(media_root) backend['environment'].append(media_root) - - dryrunner['environment'].append(dryrun_root) - backend['environment'].append(dryrun_root) else: - default_domain = os.environ.get('SUBSTRABAC_DEFAULT_DOMAIN', '') + default_domain = os.environ.get('BACKEND_DEFAULT_DOMAIN', '') if default_domain: backend['environment'].append(f"DEFAULT_DOMAIN={default_domain}") worker['environment'].append(f"DEFAULT_DOMAIN={default_domain}") scheduler['environment'].append(f"DEFAULT_DOMAIN={default_domain}") - dryrunner['environment'].append(f"DEFAULT_DOMAIN={default_domain}") + backend['environment'].append(f"RAVEN_URL={raven_backend_url}") scheduler['environment'].append(f"RAVEN_URL={raven_scheduler_url}") worker['environment'].append(f"RAVEN_URL={raven_worker_url}") - dryrunner['environment'].append(f"RAVEN_URL={raven_dryrunner_url}") - docker_compose['substrabac_services']['substrabac' + org_name_stripped] = backend - docker_compose['substrabac_services']['scheduler' + org_name_stripped] = scheduler - docker_compose['substrabac_services']['worker' + org_name_stripped] = worker - docker_compose['substrabac_services']['dryrunner' + org_name_stripped] = dryrunner + docker_compose['backend_services']['backend' + org_name_stripped] = backend + docker_compose['backend_services']['scheduler' + org_name_stripped] = scheduler + docker_compose['backend_services']['worker' + org_name_stripped] = worker # Create all services along to conf COMPOSITION = {'services': {}, 'version': '2.3', 'networks': {'default': {'external': {'name': 'net_substra'}}}} - for name, dconfig in docker_compose['substrabac_services'].items(): + for name, dconfig in docker_compose['backend_services'].items(): COMPOSITION['services'][name] = dconfig - for name, dconfig in docker_compose['substrabac_tools'].items(): + for name, dconfig in docker_compose['backend_tools'].items(): COMPOSITION['services'][name] = dconfig with open(docker_compose['path'], 'w+') as f: @@ -242,29 +275,32 @@ def stop(docker_compose=None): if docker_compose is not None: call(['docker-compose', '-f', docker_compose['path'], '--project-directory', os.path.join(dir_path, '../'), 'down', '--remove-orphans']) - else: - call(['docker-compose', '-f', os.path.join(dir_path, './docker-compose.yaml'), '--project-directory', - os.path.join(dir_path, '../'), 'down', '--remove-orphans']) def start(conf, launch_settings, no_backup): - print('Generate docker-compose file\n') - docker_compose = generate_docker_compose_file(conf, launch_settings) + nodes_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'backend/node/nodes') + if not os.path.exists(nodes_path): + print('ERROR: nodes folder does not exist, please run `python ./backend/node/generate_nodes.py`' + ' (you maybe will have to regenerate your docker images)\n') + else: + print('Generate docker-compose file\n') + docker_compose = generate_docker_compose_file(conf, launch_settings) - stop(docker_compose) + stop(docker_compose) - if no_backup: - print('Clean medias directory\n') - call(['sh', os.path.join(dir_path, '../substrabac/scripts/clean_media.sh')]) - print('Remove postgresql database\n') - call(['rm', '-rf', '/substra/backup/postgres-data']) - print('Remove rabbit database\n') - call(['rm', '-rf', '/substra/backup/rabbit-data']) 
+ if no_backup: + print('Clean medias directory\n') + call(['sh', os.path.join(dir_path, '../scripts/clean_media.sh')]) + print('Remove postgresql database\n') + call(['rm', '-rf', f'{SUBSTRA_FOLDER}/backup/postgres-data']) + print('Remove rabbit database\n') + call(['rm', '-rf', f'{SUBSTRA_FOLDER}/backup/rabbit-data']) - print('start docker-compose', flush=True) - call(['docker-compose', '-f', docker_compose['path'], '--project-directory', - os.path.join(dir_path, '../'), 'up', '-d', '--remove-orphans', '--build']) - call(['docker', 'ps', '-a', '--format', 'table {{.ID}}\t{{.Names}}\t{{.Status}}\t{{.Ports}}']) + print('start docker-compose', flush=True) + call(['docker-compose', '-f', docker_compose['path'], '--project-directory', + os.path.join(dir_path, '../'), 'up', '-d', '--remove-orphans', '--build']) + call(['docker', 'ps', '-a', '--format', 'table {{.ID}}\t{{.Names}}\t{{.Status}}\t{{.Ports}}', + '--filter', 'label=substra']) if __name__ == "__main__": @@ -283,9 +319,10 @@ def start(conf, launch_settings, no_backup): no_backup = args['no_backup'] - conf = [json.load(open(file_path, 'r')) for file_path in glob.glob('/substra/conf/*/substrabac/conf.json')] + conf = [json.load(open(file_path, 'r')) + for file_path in glob.glob(f'{SUBSTRA_FOLDER}/conf/*/substra-backend/conf.json')] - print('Build substrabac for : ', flush=True) + print('Build backend for : ', flush=True) print(' Organizations :', flush=True) for org in conf: print(' -', org['name'], flush=True) diff --git a/docker/stop.py b/docker/stop.py index 14df1feb3..269af7c3f 100644 --- a/docker/stop.py +++ b/docker/stop.py @@ -7,15 +7,11 @@ def stop(): print('stopping container') - docker_compose_path = './docker-compose.yaml' - if os.path.exists(os.path.join(dir_path, './docker-compose-dynamic.yaml')): docker_compose_path = './docker-compose-dynamic.yaml' - call(['docker-compose', '-f', os.path.join(dir_path, docker_compose_path), '--project-directory', - os.path.join(dir_path, '../'), 'down', '--remove-orphans']) - - # call(['rm', '-rf', '/substra/backup/postgres-data']) + call(['docker-compose', '-f', os.path.join(dir_path, docker_compose_path), '--project-directory', + os.path.join(dir_path, '../'), 'down', '--remove-orphans']) if __name__ == "__main__": diff --git a/docker/substra-backend/Dockerfile b/docker/substra-backend/Dockerfile new file mode 100644 index 000000000..12dab083f --- /dev/null +++ b/docker/substra-backend/Dockerfile @@ -0,0 +1,21 @@ +FROM python:3.6 + +RUN apt-get update +RUN apt-get install -y python3 python3-pip python3-dev build-essential libssl-dev libffi-dev libxml2-dev libxslt1-dev zlib1g-dev g++ gcc gfortran musl-dev postgresql-contrib +RUN apt-get install -y git curl netcat + +RUN mkdir -p /usr/src/app +WORKDIR /usr/src/app + +COPY ./backend/requirements.txt /usr/src/app/. 
+ +RUN pip3 install -r requirements.txt + +COPY ./backend/manage.py /usr/src/app/manage.py +COPY ./backend/libs /usr/src/app/libs +COPY ./backend/substrapp /usr/src/app/substrapp +COPY ./backend/events /usr/src/app/events +COPY ./backend/backend /usr/src/app/backend +COPY ./backend/node /usr/src/app/node +COPY ./backend/node-register /usr/src/app/node-register +COPY ./backend/users /usr/src/app/users diff --git a/docker/substrabac/Dockerfile b/docker/substrabac/Dockerfile deleted file mode 100644 index 053c45257..000000000 --- a/docker/substrabac/Dockerfile +++ /dev/null @@ -1,20 +0,0 @@ -FROM python:3.6 - -RUN apt-get update -RUN apt-get install -y python3 python3-pip python3-dev build-essential libssl-dev libffi-dev libxml2-dev libxslt1-dev zlib1g-dev g++ gcc gfortran musl-dev postgresql-contrib -RUN apt-get install -y git curl netcat - -RUN mkdir -p /usr/src/app -WORKDIR /usr/src/app - -COPY ./bootstrap.sh /usr/src/ -RUN cd ../; sh bootstrap.sh; cd app - -COPY ./substrabac/requirements.txt /usr/src/app/. - -RUN pip3 install -r requirements.txt - -COPY ./substrabac/manage.py /usr/src/app/manage.py -COPY ./substrabac/libs /usr/src/app/libs -COPY ./substrabac/substrapp /usr/src/app/substrapp -COPY ./substrabac/substrabac /usr/src/app/substrabac diff --git a/fabric-sdk-py_tests/fabric-sdk-py-async-events.py b/fabric-sdk-py_tests/fabric-sdk-py-async-events.py deleted file mode 100644 index b32092b64..000000000 --- a/fabric-sdk-py_tests/fabric-sdk-py-async-events.py +++ /dev/null @@ -1,30 +0,0 @@ -import os -import sys -import asyncio - -from hfc.fabric import Client -from hfc.fabric.block_decoder import FilteredBlockDecoder -from hfc.util.crypto.crypto import ecies - -from hfc.fabric.transaction.tx_context import TXContext -from hfc.fabric.transaction.tx_proposal_request import TXProposalRequest - - -dir_path = os.path.dirname(os.path.realpath(__file__)) - -async def main(): - cli = Client(net_profile=os.path.join(dir_path, '../network.json')) - admin_owkin = cli.get_user('owkin', 'admin') - - cli.new_channel('mychannel') - peer = cli.get_peer('peer1-owkin') - - events = cli.get_events(admin_owkin, peer, 'mychannel', start=0, filtered=True) - - async for v in cli.getEvents(events): - print(v) - - - -asyncio.run(main(), debug=True) - diff --git a/fabric-sdk-py_tests/fabric-sdk-py-discover.py b/fabric-sdk-py_tests/fabric-sdk-py-discover.py deleted file mode 100644 index 0bf3127d5..000000000 --- a/fabric-sdk-py_tests/fabric-sdk-py-discover.py +++ /dev/null @@ -1,88 +0,0 @@ -from hfc.fabric import Client -from hfc.fabric.channel.channel import Channel -from hfc.fabric.block_decoder import decode_fabric_MSP_config, decode_fabric_peers_info, decode_fabric_endpoints -from hfc.fabric.peer import create_peer -from hfc.fabric.user import create_user -from hfc.util.crypto.crypto import ecies -from hfc.util.keyvaluestore import FileKeyValueStore - -import pprint -import glob - -peer_config = {'clientKey': {'path': '/substra/data/orgs/owkin/tls/peer1/cli-client.key'}, - 'clientServer': {'path': '/substra/data/orgs/owkin/tls/peer1/cli-client.crt'}, - 'eventUrl': 'peer1-owkin:7053', - 'grpcOptions': {'grpc.http2.keepalive_time': 15, - 'grpc.ssl_target_name_override': 'peer1-owkin'}, - 'tlsCACerts': { - 'path': '/substra/data/orgs/owkin/ca-cert.pem'}, - 'url': 'peer1-owkin:7051'} - -peer1_owkin = create_peer(endpoint=peer_config['url'], - tls_cacerts=peer_config['tlsCACerts']['path'], - client_key=peer_config['clientKey']['path'], - client_cert=peer_config['clientServer']['path'], - opts=[(k, v) for k, v in 
peer_config['grpcOptions'].items()]) - -key_path = glob.glob('/substra/data/orgs/owkin/admin/msp/keystore/*')[0] -cert_path = '/substra/data/orgs/owkin/admin/msp/signcerts/cert.pem' - -admin_owkin = create_user(name='admin', - org='owkin', - state_store=FileKeyValueStore('/tmp/kvs/'), - msp_id='owkinMSP', - key_path=key_path, - cert_path=cert_path) - - -client = Client() - -print(client.query_peers(admin_owkin, peer1_owkin)) -print(client.query_peers(admin_owkin, peer1_owkin, channel='mychannel', local=False)) - -client.init_with_discovery(admin_owkin, peer1_owkin, - 'mychannel') - -response = Channel('', '')._discovery(admin_owkin, peer1_owkin, config=False, local=True) - -response = Channel('mychannel', '')._discovery(admin_owkin, peer1_owkin, config=True, local=False) - - -def process_config_result(config_result): - - results = {'msps': {}, - 'orderers': {}} - - for msp_name in config_result.msps: - results['msps'][msp_name] = decode_fabric_MSP_config(config_result.msps[msp_name].SerializeToString()) - - for orderer_msp in config_result.orderers: - results['orderers'][orderer_msp] = decode_fabric_endpoints(config_result.orderers[orderer_msp].endpoint) - - return results - - -def process_cc_query_res(cc_query_res): - pass - - -def process_members(members): - peers = [] - for msp_name in members.peers_by_org: - peers.append(decode_fabric_peers_info(members.peers_by_org[msp_name].peers)) - return peers - - -results = {} -for res in response.results: - # print(res) - print('-' * 100) - print('Error') - pprint.pprint(res.error) - print('-' * 50) - print('Config result') - pprint.pprint(process_config_result(res.config_result), indent=2) - # print(f'Chaincode Query result : {res.cc_query_res}') - print('Members') - pprint.pprint(process_members(res.members), indent=2) - print('#' * 100) diff --git a/fabric-sdk-py_tests/fabric-sdk-py-mass-enroll.py b/fabric-sdk-py_tests/fabric-sdk-py-mass-enroll.py deleted file mode 100644 index c9810b2e4..000000000 --- a/fabric-sdk-py_tests/fabric-sdk-py-mass-enroll.py +++ /dev/null @@ -1,33 +0,0 @@ -import random -import string - -from hfc.fabric_ca.caservice import ca_service - -cacli = ca_service(target="https://rca-owkin:7054", - ca_certs_path='/substra/data/orgs/owkin/ca-cert.pem', - ca_name='rca-owkin') - -print('Will try to enroll admin') -try: - admin = cacli.enroll('admin-owkin', 'admin-owkinpw') -except ValueError as e: - print(e) -else: - print('Admin successfully enrolled') - with open('/substra/data/orgs/owkin/ca-cert.pem', 'rb') as f: - cert = f.read() - - if cacli._ca_client.get_cainfo() == cert: - print('Distant ca cert is the same as in local filesystem') - - for x in range(0, 200): - username = ''.join( - [random.choice(string.ascii_letters + string.digits) for n in - range(9)]) - print(f'Will try to register user {username}') - try: - secret = admin.register(username, role='client', affiliation='owkin.nantes') - except ValueError as e: - print(e) - else: - print(f'Correctly registered user {username} with secret {secret}') diff --git a/fabric-sdk-py_tests/fabric-sdk-py-query-invoke.py b/fabric-sdk-py_tests/fabric-sdk-py-query-invoke.py deleted file mode 100644 index 5bef66917..000000000 --- a/fabric-sdk-py_tests/fabric-sdk-py-query-invoke.py +++ /dev/null @@ -1,113 +0,0 @@ -import os -import asyncio -import subprocess - -from hfc.fabric import Client - -from substrabac.settings.common import PROJECT_ROOT - -dir_path = os.path.dirname(os.path.realpath(__file__)) - -cli = Client(net_profile=os.path.join(dir_path, '../network.json')) 
-admin_owkin = cli.get_user('owkin', 'admin') - -cli.new_channel('mychannel') - -loop = asyncio.get_event_loop() - -from hfc.fabric_ca.caservice import ca_service - -cacli = ca_service(target="https://rca-owkin:7054", - ca_certs_path='/substra/data/orgs/owkin/ca-cert.pem', - ca_name='rca-owkin') - -print('Will try to enroll admin') -try: - admin = cacli.enroll('admin-owkin', 'admin-owkinpw') -except ValueError as e: - print(e) -except Exception as e: - print(e) -else: - print('Admin enrolled') - - os.environ['FABRIC_CFG_PATH'] = '/substra/conf/owkin/peer1' - os.environ['CORE_PEER_MSPCONFIGPATH'] = '/substra/data/orgs/owkin/user/msp' - - output = subprocess.run([os.path.join(PROJECT_ROOT, '../bin/peer'), - '--logging-level', 'DEBUG', - 'chaincode', 'query', - '-C', 'mychannel', - '-n', 'mycc', - #'--tls', - #'--clientauth', - '-c', '{"Args":["queryDataManagers"]}' - ], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - - data = output.stdout.decode('utf-8') - if data: - print(data) - else: - try: - msg = output.stderr.decode('utf-8').split('Error')[2].split('\n')[0] - data = {'message': msg} - except: - msg = output.stderr.decode('utf-8') - data = {'message': msg} - finally: - print(data) - - response = loop.run_until_complete(cli.chaincode_query( - requestor=admin_owkin, - channel_name='mychannel', - peers=['peer1-owkin'], - args=[], - cc_name='mycc', - cc_version='1.0', - fcn='queryDataManagers' - )) - print(response) - - response = loop.run_until_complete(cli.query_installed_chaincodes( - requestor=admin_owkin, - peers=['peer1-owkin'] - )) - print(response) - - response = loop.run_until_complete(cli.query_channels( - requestor=admin_owkin, - peers=['peer1-owkin'] - )) - print(response) - - response = loop.run_until_complete(cli.query_info( - requestor=admin_owkin, - channel_name='mychannel', - peers=['peer1-owkin'] - )) - print(response) - - dir_path = os.path.dirname(os.path.realpath(__file__)) - - response = loop.run_until_complete(cli.chaincode_invoke( - requestor=admin_owkin, - channel_name='mychannel', - peers=['peer1-owkin'], - args=['ISIC 2018', - '59300f1fec4f5cdd3a236c7260ed72bdd24691efdec63b7910ea84136123cecd', - 'http://chunantes.substrabac:8001/media/data_managers/59300f1fec4f5cdd3a236c7260ed72bdd24691efdec63b7910ea84136123cecd/opener.py', - 'Images', - '59300f1fec4f5cdd3a236c7260ed72bdd24691efdec63b7910ea84136123cecd', - 'http://chunantes.substrabac:8001/media/data_managers/59300f1fec4f5cdd3a236c7260ed72bdd24691efdec63b7910ea84136123cecd/description.md', - '', - 'all' - ], - cc_name='mycc', - cc_version='1.0', - fcn='registerDataManager', - wait_for_event=True, - wait_for_event_timeout=5 - )) - print(response) diff --git a/fabric-sdk-py_tests/fabric-sdk-py.py b/fabric-sdk-py_tests/fabric-sdk-py.py deleted file mode 100644 index 1e8030e89..000000000 --- a/fabric-sdk-py_tests/fabric-sdk-py.py +++ /dev/null @@ -1,67 +0,0 @@ -import random -import string - -from hfc.fabric_ca.caservice import ca_service - -cacli = ca_service(target="https://rca-owkin:7054", - ca_certs_path='/substra/data/orgs/owkin/ca-cert.pem', - ca_name='rca-owkin') - -print('Will try to enroll admin') -try: - admin = cacli.enroll('admin-owkin', 'admin-owkinpw') -except ValueError as e: - print(e) -else: - print('Admin successfully enrolled') - with open('/substra/data/orgs/owkin/ca-cert.pem', 'rb') as f: - cert = f.read() - - if cacli._ca_client.get_cainfo() == cert: - print('Distant ca cert is the same as in local filesystem') - - username = ''.join( - [random.choice(string.ascii_letters + string.digits) 
for n in - range(9)]) - print(f'Will try to register user {username}') - try: - secret = admin.register(username) - except ValueError as e: - print(e) - else: - print(f'Correctly registered user {username} with secret {secret}') - - print( - f'Will try to enroll new registered user {username} with secret {secret}') - try: - User = cacli.enroll(username, secret) - except ValueError as e: - print(e) - else: - print(f'User {username} successfully enrolled') - - # reenroll - User = cacli.reenroll(User) - - print( - f'Will try to revoke new registered user {username}') - try: - RevokedCerts, CRL = admin.revoke(username, reason='unspecified') - except ValueError as e: - print(e) - else: - print(f'User {username} successfully revoked') - - print('Will try to enroll bootstrap admin') - try: - bootstrap_admin = cacli.enroll('admin', 'adminpw') - except ValueError as e: - print(e) - else: - try: - newCRL = bootstrap_admin.generateCRL() - except Exception as e: - print('Failed to generate CRL %s', str(e)) - else: - print(newCRL) - diff --git a/fabric-sdk-py_tests/fabric-sdk-py_affiliation_service.py b/fabric-sdk-py_tests/fabric-sdk-py_affiliation_service.py deleted file mode 100644 index 7ef66c4c9..000000000 --- a/fabric-sdk-py_tests/fabric-sdk-py_affiliation_service.py +++ /dev/null @@ -1,49 +0,0 @@ -from pprint import pprint - -from hfc.fabric_ca.caservice import ca_service - -cacli = ca_service(target="https://rca-owkin:7054", - ca_certs_path='/substra/data/orgs/owkin/ca-cert.pem', - ca_name='rca-owkin') - -print('Will try to enroll bootstrap admin') -try: - bootstrap_admin = cacli.enroll('admin', 'adminpw') -except ValueError as e: - print(e) -else: - print('Admin successfully enrolled') - - print('Create affiliation Service') - - affiliationService = cacli.newAffiliationService() - - affiliation = 'department3' - - print(f'Will try to create affiliation {affiliation}') - res = affiliationService.create(bootstrap_admin, affiliation) - pprint(res) - - print(f'Will try to get affiliation {affiliation}') - res = affiliationService.getOne(affiliation, bootstrap_admin) - pprint(res) - - print('Will try to get all affiliations') - res = affiliationService.getAll(bootstrap_admin) - print('number of affiliations: ', len(res['result']['affiliations'])) - - print(f'Will try to update affiliation {affiliation} with name=\'department3bis\'') - res = affiliationService.update(affiliation, bootstrap_admin, name='department3bis') - pprint(res) - - print(f'Will try to get affiliation {affiliation} to see changes') - res = affiliationService.getOne(affiliation, bootstrap_admin) - pprint(res) - - print(f'Will try to delete affiliation {affiliation}') - res = affiliationService.delete('department3bis', bootstrap_admin) - pprint(res) - - print(f'Will try to get deleted affiliation {affiliation}') - res = affiliationService.getOne(affiliation, bootstrap_admin) - pprint(res) diff --git a/fabric-sdk-py_tests/fabric-sdk-py_certificate_service.py b/fabric-sdk-py_tests/fabric-sdk-py_certificate_service.py deleted file mode 100644 index f51571efd..000000000 --- a/fabric-sdk-py_tests/fabric-sdk-py_certificate_service.py +++ /dev/null @@ -1,63 +0,0 @@ -import random -import string - -from hfc.fabric_ca.caservice import ca_service - -cacli = ca_service(target="https://rca-owkin:7054", - ca_certs_path='/substra/data/orgs/owkin/ca-cert.pem', - ca_name='rca-owkin') - -print('Will try to enroll admin') -try: - bootstrap_admin = cacli.enroll('admin', 'adminpw') -except ValueError as e: - print(e) -else: - print('Admin successfully 
enrolled') - - print('Create affiliation Service') - - certificateService = cacli.newCertificateService() - - print(f'Will try to get certificates') - res = certificateService.getCertificates(bootstrap_admin) - print(len(res['result'])) - - print(f'Will try to get certificates admin') - res = certificateService.getCertificates(bootstrap_admin, 'admin') - print(len(res['result'])) - - print('Will try to enroll admin') - try: - admin = cacli.enroll('admin-owkin', 'admin-owkinpw') - except ValueError as e: - print(e) - else: - print('Admin successfully enrolled') - - print('Create identity Service') - - identityService = cacli.newIdentityService() - - username = ''.join( - [random.choice(string.ascii_letters + string.digits) for n in - range(9)]) - print(f'Will try to register user {username}') - secret = identityService.create(admin, username) - print(f'Correctly registered user {username} with secret {secret}') - - print(f'Will try to get certificates {username} from admin') - res = certificateService.getCertificates(admin, username) - print(len(res['result'])) - - print( - f'Will try to enroll user {username} with original password {secret}') - try: - user = cacli.enroll(username, secret) - except: - print('User cannot enroll with old password') - else: - - print(f'Will try to get certificates {username} from user') - res = certificateService.getCertificates(user) - print(res) diff --git a/fabric-sdk-py_tests/fabric-sdk-py_identity_service.py b/fabric-sdk-py_tests/fabric-sdk-py_identity_service.py deleted file mode 100644 index 6083175a6..000000000 --- a/fabric-sdk-py_tests/fabric-sdk-py_identity_service.py +++ /dev/null @@ -1,63 +0,0 @@ -import random -import string -from pprint import pprint - -from hfc.fabric_ca.caservice import ca_service - -cacli = ca_service(target="https://rca-owkin:7054", - ca_certs_path='/substra/data/orgs/owkin/ca-cert.pem', - ca_name='rca-owkin') - -print('Will try to enroll admin') -try: - admin = cacli.enroll('admin-owkin', 'admin-owkinpw') -except ValueError as e: - print(e) -else: - print('Admin successfully enrolled') - - print('Create identity Service') - - identityService = cacli.newIdentityService() - - username = ''.join( - [random.choice(string.ascii_letters + string.digits) for n in - range(9)]) - print(f'Will try to register user {username}') - secret = identityService.create(admin, username) - print(f'Correctly registered user {username} with secret {secret}') - - print(f'Will try to get user {username}') - res = identityService.getOne(username, admin) - pprint(res) - print('Will try to get all users') - res = identityService.getAll(admin) - print('number of users: ', len(res['result']['identities'])) - - print(f'Will try to update user {username} with maxEnrollments=3, affiliation=\'.\' and secret=bar') - res = identityService.update(username, admin, maxEnrollments=3, - affiliation='.', enrollmentSecret='bar') - pprint(res) - - print(f'Will try to enroll user {username} with original password {secret}') - try: - cacli.enroll(username, secret) - except: - print('User cannot enroll with old password') - else: - print('/!\ User password update did not work correctly as he is able to enroll with old password') - finally: - print(f'Will try to enroll user {username} with modified password bar') - cacli.enroll(username, 'bar') - - print(f'Will try to get user {username} to see changes') - res = identityService.getOne(username, admin) - pprint(res) - - print(f'Will try to delete user {username}') - res = identityService.delete(username, admin) - 
pprint(res) - - print(f'Will try to get deleted user {username}') - res = identityService.getOne(username, admin) - pprint(res) diff --git a/substrabac/fixtures/chunantes/algos/algo0/algo.tar.gz b/fixtures/chunantes/algos/algo0/algo.tar.gz similarity index 100% rename from substrabac/fixtures/chunantes/algos/algo0/algo.tar.gz rename to fixtures/chunantes/algos/algo0/algo.tar.gz diff --git a/substrabac/fixtures/chunantes/algos/algo0/algo.zip b/fixtures/chunantes/algos/algo0/algo.zip similarity index 100% rename from substrabac/fixtures/chunantes/algos/algo0/algo.zip rename to fixtures/chunantes/algos/algo0/algo.zip diff --git a/substrabac/fixtures/chunantes/algos/algo0/description.md b/fixtures/chunantes/algos/algo0/description.md similarity index 100% rename from substrabac/fixtures/chunantes/algos/algo0/description.md rename to fixtures/chunantes/algos/algo0/description.md diff --git a/substrabac/fixtures/chunantes/algos/algo1/algo.tar.gz b/fixtures/chunantes/algos/algo1/algo.tar.gz similarity index 100% rename from substrabac/fixtures/chunantes/algos/algo1/algo.tar.gz rename to fixtures/chunantes/algos/algo1/algo.tar.gz diff --git a/substrabac/fixtures/chunantes/algos/algo1/description.md b/fixtures/chunantes/algos/algo1/description.md similarity index 100% rename from substrabac/fixtures/chunantes/algos/algo1/description.md rename to fixtures/chunantes/algos/algo1/description.md diff --git a/substrabac/fixtures/chunantes/algos/algo2/algo.zip b/fixtures/chunantes/algos/algo2/algo.zip similarity index 100% rename from substrabac/fixtures/chunantes/algos/algo2/algo.zip rename to fixtures/chunantes/algos/algo2/algo.zip diff --git a/fixtures/chunantes/algos/algo3/algo.tar.gz b/fixtures/chunantes/algos/algo3/algo.tar.gz new file mode 100644 index 000000000..f3bbb776a Binary files /dev/null and b/fixtures/chunantes/algos/algo3/algo.tar.gz differ diff --git a/substrabac/fixtures/chunantes/algos/algo3/description.md b/fixtures/chunantes/algos/algo3/description.md similarity index 100% rename from substrabac/fixtures/chunantes/algos/algo3/description.md rename to fixtures/chunantes/algos/algo3/description.md diff --git a/substrabac/fixtures/chunantes/algos/algo4/algo.tar.gz b/fixtures/chunantes/algos/algo4/algo.tar.gz similarity index 100% rename from substrabac/fixtures/chunantes/algos/algo4/algo.tar.gz rename to fixtures/chunantes/algos/algo4/algo.tar.gz diff --git a/substrabac/fixtures/chunantes/algos/algo4/algo.zip b/fixtures/chunantes/algos/algo4/algo.zip similarity index 100% rename from substrabac/fixtures/chunantes/algos/algo4/algo.zip rename to fixtures/chunantes/algos/algo4/algo.zip diff --git a/substrabac/fixtures/chunantes/algos/algo4/description.md b/fixtures/chunantes/algos/algo4/description.md similarity index 100% rename from substrabac/fixtures/chunantes/algos/algo4/description.md rename to fixtures/chunantes/algos/algo4/description.md diff --git a/substrabac/fixtures/chunantes/datamanagers/datamanager0/description.md b/fixtures/chunantes/datamanagers/datamanager0/description.md similarity index 100% rename from substrabac/fixtures/chunantes/datamanagers/datamanager0/description.md rename to fixtures/chunantes/datamanagers/datamanager0/description.md diff --git a/substrabac/fixtures/chunantes/datamanagers/datamanager0/opener.py b/fixtures/chunantes/datamanagers/datamanager0/opener.py similarity index 97% rename from substrabac/fixtures/chunantes/datamanagers/datamanager0/opener.py rename to fixtures/chunantes/datamanagers/datamanager0/opener.py index f7dbd14d2..134325b2d 100644 
--- a/substrabac/fixtures/chunantes/datamanagers/datamanager0/opener.py +++ b/fixtures/chunantes/datamanagers/datamanager0/opener.py @@ -78,7 +78,7 @@ def fake_y(self): """Make and return the ISIC like labels as np arrays.""" return np.eye(CLASSES)[np.arange(n_sample) % CLASSES].astype('uint8') - def save_pred(self, y_pred, path): + def save_predictions(self, y_pred, path): """Save prediction in path :param y_pred: predicted target variable vector @@ -91,7 +91,7 @@ def save_pred(self, y_pred, path): writer = csv.writer(f) writer.writerows(y_pred) - def get_pred(self, path): + def get_predictions(self, path): """Get predictions which were saved using the save_pred function :param folder: path to the folder where the previously predicted target variable vector has been saved diff --git a/substrabac/fixtures/chunantes/datasamples/datasample0/0024899.tar.gz b/fixtures/chunantes/datasamples/datasample0/0024899.tar.gz similarity index 100% rename from substrabac/fixtures/chunantes/datasamples/datasample0/0024899.tar.gz rename to fixtures/chunantes/datasamples/datasample0/0024899.tar.gz diff --git a/substrabac/fixtures/chunantes/datasamples/datasample0/0024899.zip b/fixtures/chunantes/datasamples/datasample0/0024899.zip similarity index 100% rename from substrabac/fixtures/chunantes/datasamples/datasample0/0024899.zip rename to fixtures/chunantes/datasamples/datasample0/0024899.zip diff --git a/substrabac/fixtures/chunantes/datasamples/datasample1/0024700.tar.gz b/fixtures/chunantes/datasamples/datasample1/0024700.tar.gz similarity index 100% rename from substrabac/fixtures/chunantes/datasamples/datasample1/0024700.tar.gz rename to fixtures/chunantes/datasamples/datasample1/0024700.tar.gz diff --git a/substrabac/fixtures/chunantes/datasamples/datasample1/0024700.zip b/fixtures/chunantes/datasamples/datasample1/0024700.zip similarity index 100% rename from substrabac/fixtures/chunantes/datasamples/datasample1/0024700.zip rename to fixtures/chunantes/datasamples/datasample1/0024700.zip diff --git a/fixtures/chunantes/datasamples/train/0024306.zip b/fixtures/chunantes/datasamples/train/0024306.zip new file mode 100644 index 000000000..9e579cd4e Binary files /dev/null and b/fixtures/chunantes/datasamples/train/0024306.zip differ diff --git a/substrabac/fixtures/chunantes/datasamples/train/0024306/IMG_0024306.jpg b/fixtures/chunantes/datasamples/train/0024306/IMG_0024306.jpg similarity index 100% rename from substrabac/fixtures/chunantes/datasamples/train/0024306/IMG_0024306.jpg rename to fixtures/chunantes/datasamples/train/0024306/IMG_0024306.jpg diff --git a/substrabac/fixtures/chunantes/datasamples/train/0024306/LABEL_0024306.csv b/fixtures/chunantes/datasamples/train/0024306/LABEL_0024306.csv similarity index 100% rename from substrabac/fixtures/chunantes/datasamples/train/0024306/LABEL_0024306.csv rename to fixtures/chunantes/datasamples/train/0024306/LABEL_0024306.csv diff --git a/fixtures/chunantes/datasamples/train/0024307.zip b/fixtures/chunantes/datasamples/train/0024307.zip new file mode 100644 index 000000000..e4c7b3a09 Binary files /dev/null and b/fixtures/chunantes/datasamples/train/0024307.zip differ diff --git a/substrabac/fixtures/chunantes/datasamples/train/0024307/IMG_0024307.jpg b/fixtures/chunantes/datasamples/train/0024307/IMG_0024307.jpg similarity index 100% rename from substrabac/fixtures/chunantes/datasamples/train/0024307/IMG_0024307.jpg rename to fixtures/chunantes/datasamples/train/0024307/IMG_0024307.jpg diff --git 
a/substrabac/fixtures/chunantes/datasamples/train/0024307/LABEL_0024307.csv b/fixtures/chunantes/datasamples/train/0024307/LABEL_0024307.csv similarity index 100% rename from substrabac/fixtures/chunantes/datasamples/train/0024307/LABEL_0024307.csv rename to fixtures/chunantes/datasamples/train/0024307/LABEL_0024307.csv diff --git a/substrabac/fixtures/chunantes/datasamples/train/0024308/IMG_0024308.jpg b/fixtures/chunantes/datasamples/train/0024308/IMG_0024308.jpg similarity index 100% rename from substrabac/fixtures/chunantes/datasamples/train/0024308/IMG_0024308.jpg rename to fixtures/chunantes/datasamples/train/0024308/IMG_0024308.jpg diff --git a/substrabac/fixtures/chunantes/datasamples/train/0024308/LABEL_0024308.csv b/fixtures/chunantes/datasamples/train/0024308/LABEL_0024308.csv similarity index 100% rename from substrabac/fixtures/chunantes/datasamples/train/0024308/LABEL_0024308.csv rename to fixtures/chunantes/datasamples/train/0024308/LABEL_0024308.csv diff --git a/substrabac/fixtures/chunantes/datasamples/train/0024310.zip b/fixtures/chunantes/datasamples/train/0024310.zip similarity index 100% rename from substrabac/fixtures/chunantes/datasamples/train/0024310.zip rename to fixtures/chunantes/datasamples/train/0024310.zip diff --git a/substrabac/fixtures/chunantes/models/model0/model b/fixtures/chunantes/models/model0/model similarity index 100% rename from substrabac/fixtures/chunantes/models/model0/model rename to fixtures/chunantes/models/model0/model diff --git a/fixtures/chunantes/objectives/objective0/Dockerfile b/fixtures/chunantes/objectives/objective0/Dockerfile new file mode 100644 index 000000000..a9166d836 --- /dev/null +++ b/fixtures/chunantes/objectives/objective0/Dockerfile @@ -0,0 +1,7 @@ +FROM substrafoundation/substra-tools:0.0.1 + +RUN mkdir -p /sandbox/opener +WORKDIR /sandbox +COPY metrics.py . 
+ +ENTRYPOINT ["python3", "metrics.py"] diff --git a/substrabac/fixtures/chunantes/objectives/objective0/description.md b/fixtures/chunantes/objectives/objective0/description.md similarity index 100% rename from substrabac/fixtures/chunantes/objectives/objective0/description.md rename to fixtures/chunantes/objectives/objective0/description.md diff --git a/substrabac/fixtures/chunantes/objectives/objective0/metrics.py b/fixtures/chunantes/objectives/objective0/metrics.py similarity index 57% rename from substrabac/fixtures/chunantes/objectives/objective0/metrics.py rename to fixtures/chunantes/objectives/objective0/metrics.py index 95652de06..04ace0354 100644 --- a/substrabac/fixtures/chunantes/objectives/objective0/metrics.py +++ b/fixtures/chunantes/objectives/objective0/metrics.py @@ -1,8 +1,12 @@ from sklearn.metrics import recall_score -from substratools import Metrics as MetricsABC +import substratools as tools -class Metrics(MetricsABC): +class Metrics(tools.Metrics): def score(self, y_true, y_pred): return recall_score(y_true.argmax(axis=1), y_pred.argmax(axis=1), average='macro') + + +if __name__ == '__main__': + tools.metrics.execute(Metrics()) diff --git a/fixtures/chunantes/objectives/objective0/metrics.zip b/fixtures/chunantes/objectives/objective0/metrics.zip new file mode 100644 index 000000000..911693f54 Binary files /dev/null and b/fixtures/chunantes/objectives/objective0/metrics.zip differ diff --git a/substrabac/fixtures/dataset.json b/fixtures/dataset.json similarity index 100% rename from substrabac/fixtures/dataset.json rename to fixtures/dataset.json diff --git a/substrabac/fixtures/objective.json b/fixtures/objective.json similarity index 100% rename from substrabac/fixtures/objective.json rename to fixtures/objective.json diff --git a/substrabac/fixtures/owkin/datamanagers/datamanager0/description.md b/fixtures/owkin/datamanagers/datamanager0/description.md similarity index 100% rename from substrabac/fixtures/owkin/datamanagers/datamanager0/description.md rename to fixtures/owkin/datamanagers/datamanager0/description.md diff --git a/substrabac/fixtures/owkin/datamanagers/datamanager0/opener.py b/fixtures/owkin/datamanagers/datamanager0/opener.py similarity index 97% rename from substrabac/fixtures/owkin/datamanagers/datamanager0/opener.py rename to fixtures/owkin/datamanagers/datamanager0/opener.py index 8d54fcd65..d8c0dc6ba 100644 --- a/substrabac/fixtures/owkin/datamanagers/datamanager0/opener.py +++ b/fixtures/owkin/datamanagers/datamanager0/opener.py @@ -78,7 +78,7 @@ def fake_y(self): """Make and return the ISIC like labels as np arrays.""" return np.eye(CLASSES)[np.arange(n_sample) % CLASSES].astype('uint8') - def save_pred(self, y_pred, path): + def save_predictions(self, y_pred, path): """Save prediction in path :param y_pred: predicted target variable vector @@ -91,7 +91,7 @@ def save_pred(self, y_pred, path): writer = csv.writer(f) writer.writerows(y_pred) - def get_pred(self, path): + def get_predictions(self, path): """Get predictions which were saved using the save_pred function :param folder: path to the folder where the previously predicted target variable vector has been saved diff --git a/substrabac/fixtures/owkin/datasamples/datasample0/0024315.tar.gz b/fixtures/owkin/datasamples/datasample0/0024315.tar.gz similarity index 100% rename from substrabac/fixtures/owkin/datasamples/datasample0/0024315.tar.gz rename to fixtures/owkin/datasamples/datasample0/0024315.tar.gz diff --git a/substrabac/fixtures/owkin/datasamples/datasample1/0024701.tar.gz 
b/fixtures/owkin/datasamples/datasample1/0024701.tar.gz similarity index 100% rename from substrabac/fixtures/owkin/datasamples/datasample1/0024701.tar.gz rename to fixtures/owkin/datasamples/datasample1/0024701.tar.gz diff --git a/substrabac/fixtures/owkin/datasamples/datasample2/0024318.tar.gz b/fixtures/owkin/datasamples/datasample2/0024318.tar.gz similarity index 100% rename from substrabac/fixtures/owkin/datasamples/datasample2/0024318.tar.gz rename to fixtures/owkin/datasamples/datasample2/0024318.tar.gz diff --git a/substrabac/fixtures/owkin/datasamples/datasample3/0024317.tar.gz b/fixtures/owkin/datasamples/datasample3/0024317.tar.gz similarity index 100% rename from substrabac/fixtures/owkin/datasamples/datasample3/0024317.tar.gz rename to fixtures/owkin/datasamples/datasample3/0024317.tar.gz diff --git a/substrabac/fixtures/owkin/datasamples/datasample4/0024900.tar.gz b/fixtures/owkin/datasamples/datasample4/0024900.tar.gz similarity index 100% rename from substrabac/fixtures/owkin/datasamples/datasample4/0024900.tar.gz rename to fixtures/owkin/datasamples/datasample4/0024900.tar.gz diff --git a/substrabac/fixtures/owkin/datasamples/datasample4/0024900.zip b/fixtures/owkin/datasamples/datasample4/0024900.zip similarity index 100% rename from substrabac/fixtures/owkin/datasamples/datasample4/0024900.zip rename to fixtures/owkin/datasamples/datasample4/0024900.zip diff --git a/substrabac/fixtures/owkin/datasamples/datasample5/0024316.tar.gz b/fixtures/owkin/datasamples/datasample5/0024316.tar.gz similarity index 100% rename from substrabac/fixtures/owkin/datasamples/datasample5/0024316.tar.gz rename to fixtures/owkin/datasamples/datasample5/0024316.tar.gz diff --git a/substrabac/fixtures/owkin/datasamples/test/0024900.zip b/fixtures/owkin/datasamples/test/0024900.zip similarity index 100% rename from substrabac/fixtures/owkin/datasamples/test/0024900.zip rename to fixtures/owkin/datasamples/test/0024900.zip diff --git a/fixtures/owkin/datasamples/test/0024900/IMG_0024900.jpg b/fixtures/owkin/datasamples/test/0024900/IMG_0024900.jpg new file mode 100644 index 000000000..086f863c3 Binary files /dev/null and b/fixtures/owkin/datasamples/test/0024900/IMG_0024900.jpg differ diff --git a/fixtures/owkin/datasamples/test/0024900/LABEL_0024900.csv b/fixtures/owkin/datasamples/test/0024900/LABEL_0024900.csv new file mode 100644 index 000000000..d8044bb00 --- /dev/null +++ b/fixtures/owkin/datasamples/test/0024900/LABEL_0024900.csv @@ -0,0 +1 @@ +1.0,0.0,0.0,0.0,0.0,0.0,0.0 diff --git a/substrabac/fixtures/owkin/datasamples/test/0024901.zip b/fixtures/owkin/datasamples/test/0024901.zip similarity index 100% rename from substrabac/fixtures/owkin/datasamples/test/0024901.zip rename to fixtures/owkin/datasamples/test/0024901.zip diff --git a/fixtures/owkin/datasamples/test/0024901/IMG_0024901.jpg b/fixtures/owkin/datasamples/test/0024901/IMG_0024901.jpg new file mode 100644 index 000000000..f1cb6344c Binary files /dev/null and b/fixtures/owkin/datasamples/test/0024901/IMG_0024901.jpg differ diff --git a/fixtures/owkin/datasamples/test/0024901/LABEL_0024901.csv b/fixtures/owkin/datasamples/test/0024901/LABEL_0024901.csv new file mode 100644 index 000000000..ff746af51 --- /dev/null +++ b/fixtures/owkin/datasamples/test/0024901/LABEL_0024901.csv @@ -0,0 +1 @@ +0.0,1.0,0.0,0.0,0.0,0.0,0.0 diff --git a/substrabac/fixtures/owkin/datasamples/test/0024902.zip b/fixtures/owkin/datasamples/test/0024902.zip similarity index 100% rename from substrabac/fixtures/owkin/datasamples/test/0024902.zip rename to 
fixtures/owkin/datasamples/test/0024902.zip diff --git a/fixtures/owkin/datasamples/test/0024902/IMG_0024902.jpg b/fixtures/owkin/datasamples/test/0024902/IMG_0024902.jpg new file mode 100644 index 000000000..5d11774c6 Binary files /dev/null and b/fixtures/owkin/datasamples/test/0024902/IMG_0024902.jpg differ diff --git a/fixtures/owkin/datasamples/test/0024902/LABEL_0024902.csv b/fixtures/owkin/datasamples/test/0024902/LABEL_0024902.csv new file mode 100644 index 000000000..ff746af51 --- /dev/null +++ b/fixtures/owkin/datasamples/test/0024902/LABEL_0024902.csv @@ -0,0 +1 @@ +0.0,1.0,0.0,0.0,0.0,0.0,0.0 diff --git a/substrabac/fixtures/owkin/datasamples/test/0024903.zip b/fixtures/owkin/datasamples/test/0024903.zip similarity index 100% rename from substrabac/fixtures/owkin/datasamples/test/0024903.zip rename to fixtures/owkin/datasamples/test/0024903.zip diff --git a/fixtures/owkin/datasamples/test/0024903/IMG_0024903.jpg b/fixtures/owkin/datasamples/test/0024903/IMG_0024903.jpg new file mode 100644 index 000000000..6525fccea Binary files /dev/null and b/fixtures/owkin/datasamples/test/0024903/IMG_0024903.jpg differ diff --git a/fixtures/owkin/datasamples/test/0024903/LABEL_0024903.csv b/fixtures/owkin/datasamples/test/0024903/LABEL_0024903.csv new file mode 100644 index 000000000..ff746af51 --- /dev/null +++ b/fixtures/owkin/datasamples/test/0024903/LABEL_0024903.csv @@ -0,0 +1 @@ +0.0,1.0,0.0,0.0,0.0,0.0,0.0 diff --git a/substrabac/fixtures/owkin/datasamples/test/0024904.zip b/fixtures/owkin/datasamples/test/0024904.zip similarity index 100% rename from substrabac/fixtures/owkin/datasamples/test/0024904.zip rename to fixtures/owkin/datasamples/test/0024904.zip diff --git a/fixtures/owkin/datasamples/test/0024904/IMG_0024904.jpg b/fixtures/owkin/datasamples/test/0024904/IMG_0024904.jpg new file mode 100644 index 000000000..8913cb7ad Binary files /dev/null and b/fixtures/owkin/datasamples/test/0024904/IMG_0024904.jpg differ diff --git a/fixtures/owkin/datasamples/test/0024904/LABEL_0024904.csv b/fixtures/owkin/datasamples/test/0024904/LABEL_0024904.csv new file mode 100644 index 000000000..8f42c719c --- /dev/null +++ b/fixtures/owkin/datasamples/test/0024904/LABEL_0024904.csv @@ -0,0 +1 @@ +0.0,0.0,0.0,0.0,0.0,0.0,1.0 diff --git a/substrabac/fixtures/owkin/datasamples/test/0024905.zip b/fixtures/owkin/datasamples/test/0024905.zip similarity index 100% rename from substrabac/fixtures/owkin/datasamples/test/0024905.zip rename to fixtures/owkin/datasamples/test/0024905.zip diff --git a/fixtures/owkin/datasamples/test/0024905/IMG_0024905.jpg b/fixtures/owkin/datasamples/test/0024905/IMG_0024905.jpg new file mode 100644 index 000000000..6d9e6460b Binary files /dev/null and b/fixtures/owkin/datasamples/test/0024905/IMG_0024905.jpg differ diff --git a/fixtures/owkin/datasamples/test/0024905/LABEL_0024905.csv b/fixtures/owkin/datasamples/test/0024905/LABEL_0024905.csv new file mode 100644 index 000000000..ff746af51 --- /dev/null +++ b/fixtures/owkin/datasamples/test/0024905/LABEL_0024905.csv @@ -0,0 +1 @@ +0.0,1.0,0.0,0.0,0.0,0.0,0.0 diff --git a/fixtures/owkin/objectives/objective0/Dockerfile b/fixtures/owkin/objectives/objective0/Dockerfile new file mode 100644 index 000000000..a9166d836 --- /dev/null +++ b/fixtures/owkin/objectives/objective0/Dockerfile @@ -0,0 +1,7 @@ +FROM substrafoundation/substra-tools:0.0.1 + +RUN mkdir -p /sandbox/opener +WORKDIR /sandbox +COPY metrics.py . 
+ +ENTRYPOINT ["python3", "metrics.py"] diff --git a/substrabac/fixtures/owkin/objectives/objective0/description.md b/fixtures/owkin/objectives/objective0/description.md similarity index 100% rename from substrabac/fixtures/owkin/objectives/objective0/description.md rename to fixtures/owkin/objectives/objective0/description.md diff --git a/substrabac/fixtures/owkin/objectives/objective0/metrics.py b/fixtures/owkin/objectives/objective0/metrics.py similarity index 57% rename from substrabac/fixtures/owkin/objectives/objective0/metrics.py rename to fixtures/owkin/objectives/objective0/metrics.py index 95652de06..04ace0354 100644 --- a/substrabac/fixtures/owkin/objectives/objective0/metrics.py +++ b/fixtures/owkin/objectives/objective0/metrics.py @@ -1,8 +1,12 @@ from sklearn.metrics import recall_score -from substratools import Metrics as MetricsABC +import substratools as tools -class Metrics(MetricsABC): +class Metrics(tools.Metrics): def score(self, y_true, y_pred): return recall_score(y_true.argmax(axis=1), y_pred.argmax(axis=1), average='macro') + + +if __name__ == '__main__': + tools.metrics.execute(Metrics()) diff --git a/fixtures/owkin/objectives/objective0/metrics.zip b/fixtures/owkin/objectives/objective0/metrics.zip new file mode 100644 index 000000000..1123bcefe Binary files /dev/null and b/fixtures/owkin/objectives/objective0/metrics.zip differ diff --git a/generateNetworkFile.py b/generateNetworkFile.py deleted file mode 100644 index f97e4fcff..000000000 --- a/generateNetworkFile.py +++ /dev/null @@ -1,124 +0,0 @@ -import json -import os - -dir_path = '.' - - -def generate_network_file(conf): - network_conf = {'name': 'substra', - 'description': 'Substra network', - 'version': '0.1', - 'client': {'organization': 'owkin', - 'credentialStore': {'path': '/tmp/hfc-kvs', - 'cryptoStore': { - 'path': '/tmp/hfc-cvs'}, - 'wallet': 'wallet-name'} - }, - 'organizations': {}, - 'orderers': {}, - 'peers': {}, - 'certificateAuthorities': {} - - } - for orderer in conf['orderers']: - # print(orderer) - admin_private_key = \ - os.listdir('%s/msp/keystore/' % orderer['admin_home'])[0] - network_conf['organizations'][orderer['name']] = { - 'mspid': orderer['msp_id'], - 'orderers': [orderer['host']], - 'certificateAuthorities': [orderer['ca']['name']], - 'users': {'admin': { - 'cert': '%s/msp/signcerts/cert.pem' % orderer['admin_home'], - 'private_key': '%s/msp/keystore/%s' % ( - orderer['admin_home'], admin_private_key)} - } - } - network_conf['orderers'][orderer['name']] = { - 'url': '%s:%s' % (orderer['host'], orderer['port']), - 'grpcOptions': {'grpc.ssl_target_name_override': orderer['host'], - 'grpc-max-send-message-length': 15 - }, - 'tlsCACerts': {'path': orderer['ca']['certfile']}, - 'clientKey': {'path': orderer['tls']['key']}, - 'clientCert': {'path': orderer['tls']['cert']}, - } - - network_conf['certificateAuthorities'][orderer['ca']['name']] = { - 'url': '%s:%s' % ( - orderer['ca']['host'], orderer['ca']['host_port']), - 'grpcOptions': {'verify': True}, - 'tlsCACerts': {'path': orderer['ca']['certfile']}, - 'registrar': [{'enrollId': orderer['users']['admin']['name'], - 'enrollSecret': orderer['users']['admin']['pass'] - }] - } - - for org in conf['orgs']: - # print(org) - admin_private_key = \ - os.listdir('%s/msp/keystore/' % org['users']['admin']['home'])[0] - user_private_key = \ - os.listdir('%s/msp/keystore/' % org['users']['user']['home'])[0] - network_conf['organizations'][org['name']] = {'mspid': org['msp_id'], - 'peers': [peer['host'] - for peer in - org['peers']], - 
'certificateAuthorities': [ - org['ca']['name']], - 'users': {'admin': { - 'cert': '%s/msp/signcerts/cert.pem' % - org['users'][ - 'admin'][ - 'home'], - 'private_key': '%s/msp/keystore/%s' % ( - org['users'][ - 'admin'][ - 'home'], - admin_private_key)}, - 'user': { - 'cert': '%s/msp/signcerts/cert.pem' % - org[ - 'users'][ - 'user'][ - 'home'], - 'private_key': '%s/msp/keystore/%s' % ( - org['users'][ - 'user'][ - 'home'], - user_private_key)} - } - } - - network_conf['certificateAuthorities'][org['ca']['name']] = { - 'url': '%s:%s' % (org['ca']['host'], org['ca']['host_port']), - 'grpcOptions': {'verify': True}, - 'tlsCACerts': {'path': org['ca']['certfile']}, - 'registrar': [{'enrollId': org['users']['admin']['name'], - 'enrollSecret': org['users']['admin']['pass'] - }] - } - - for peer in org['peers']: - network_conf['peers'][peer['host']] = { - 'url': '%s:%s' % (peer['host'], peer['host_port']), - 'eventUrl': '%s:%s' % (peer['host'], peer['host_event_port']), - 'grpcOptions': { - 'grpc.ssl_target_name_override': peer['host'], - 'grpc.http2.keepalive_time': 15, - }, - 'tlsCACerts': {'path': org['ca']['certfile']}, - 'clientKey': {'path': peer['tls']['clientKey']}, - 'clientCert': {'path': peer['tls']['clientCert']}, - } - - with open(os.path.join(dir_path, 'network.json'), 'w') as outfile: - json.dump(network_conf, outfile, indent=4, sort_keys=True) - - return network_conf - - -if __name__ == "__main__": - conf_path = '/substra/conf/conf.json' - conf = json.load(open(conf_path, 'r')) - generate_network_file(conf) diff --git a/populate.py b/populate.py new file mode 100644 index 000000000..992be8973 --- /dev/null +++ b/populate.py @@ -0,0 +1,442 @@ +import argparse +import os +import json +import shutil +import tempfile +import time +import zipfile +import logging + +import substra + +from termcolor import colored + +logging.basicConfig(filename='populate.log', + format='[%(asctime)-15s: %(levelname)s] %(message)s') + +dir_path = os.path.dirname(os.path.realpath(__file__)) + +USER, PASSWORD = ('admin', 'admin') +SUBSTRA_FOLDER = os.getenv('SUBSTRA_PATH', '/substra') +server_path = f'{SUBSTRA_FOLDER}/servermedias' + +client = substra.Client() + + +PUBLIC_PERMISSIONS = {'public': True, 'authorized_ids': []} + + +def setup_config(network='docker'): + print('Init config for owkin and chunantes') + if network == 'docker': + # get first available user + client.add_profile('owkin', 'substra', 'p@$swr0d44', 'http://substra-backend.owkin.xyz:8000', '0.0') + client.add_profile('chunantes', 'substra', 'p@$swr0d45', 'http://substra-backend.chunantes.xyz:8001', '0.0') + client.add_profile('clb', 'substra', 'p@$swr0d46', 'http://substra-backend.clb.xyz:8002', '0.0') + if network == 'skaffold': + # the usernames and passwords are defined in the skaffold.yaml file + client.add_profile('owkin', 'node-1', 'p@$swr0d44', 'http://substra-backend.node-1.com', '0.0') + client.add_profile('chunantes', 'node-2', 'p@$swr0d45', 'http://substra-backend.node-2.com', '0.0') + client.add_profile('clb', 'node-3', 'p@$swr0d46', 'http://substra-backend.node-3.com', '0.0') + + +def zip_folder(path, destination): + zipf = zipfile.ZipFile(destination, 'w', zipfile.ZIP_DEFLATED) + for root, dirs, files in os.walk(path): + for f in files: + abspath = os.path.join(root, f) + archive_path = os.path.relpath(abspath, start=path) + zipf.write(abspath, arcname=archive_path) + zipf.close() + + +def get_or_create(data, profile, asset, local=True): + + client.set_profile(profile) + + method_kwargs = {} + if not local: + 
method_kwargs['local'] = False + + method = getattr(client, f'add_{asset}') + + try: + r = method(data, **method_kwargs) + + except substra.exceptions.AlreadyExists as e: + print(colored(e, 'cyan')) + key_or_keys = e.pkhash + + else: + print(colored(json.dumps(r, indent=2), 'green')) + + key_or_keys = [x.get('pkhash', x.get('key')) + for x in r] if isinstance(r, list) else r.get('pkhash', r.get('key')) + + return key_or_keys + + +def update_datamanager(data_manager_key, data, profile): + client.set_profile(profile) + try: + r = client.update_dataset(data_manager_key, data) + + except substra.exceptions.InvalidRequest as e: + # FIXME if the data manager is already associated with the objective + # backend answer with a 400 and a raw error coming from the + # ledger. + # this case will be handled soon, with the fabric SDK. + print(colored(str(e), 'red')) + + else: + print(colored(json.dumps(r, indent=2), 'green')) + + +def login(*args): + for org in args: + print(f'Login with {org}') + client.set_profile(org) + try: + client.login() + except Exception as e: + raise Exception(f'login failed: {str(e)}') + + +def do_populate(): + + parser = argparse.ArgumentParser() + group = parser.add_mutually_exclusive_group() + group.add_argument('-o', '--one-org', action='store_const', dest='nb_org', const=1, + help='Launch populate with one org') + group.add_argument('-tw', '--two-orgs', action='store_const', dest='nb_org', const=2, + help='Launch populate with two orgs') + group.add_argument('-th', '--three-orgs', action='store_const', dest='nb_org', const=3, + help='Launch populate with three orgs') + parser.add_argument('-a', '--archive', action='store_true', + help='Launch populate with archive data samples only') + parser.add_argument('-s', '--skaffold', action='store_true', + help='Launch populate with skaffold (K8S) network') + parser.set_defaults(nb_org=2) + args = vars(parser.parse_args()) + + network_type = 'skaffold' if args['skaffold'] else 'docker' + setup_config(network_type) + + if args['nb_org'] == 1: + org_0 = org_1 = org_2 = 'owkin' + elif args['nb_org'] == 2: + org_0 = org_2 = 'owkin' + org_1 = 'chunantes' + elif args['nb_org'] == 3: + org_0 = 'owkin' + org_1 = 'chunantes' + org_2 = 'clb' + else: + raise Exception(f"Number of orgs {args['nb_org']} not in [1, 2, 3]") + + login(org_0, org_1, org_2) + + print(f'will create datamanager with {org_1}') + # create datamanager with org1 + data = { + 'name': 'ISIC 2018', + 'data_opener': os.path.join(dir_path, './fixtures/chunantes/datamanagers/datamanager0/opener.py'), + 'type': 'Images', + 'description': os.path.join(dir_path, './fixtures/chunantes/datamanagers/datamanager0/description.md'), + 'permissions': PUBLIC_PERMISSIONS, + } + data_manager_org1_key = get_or_create(data, org_1, 'dataset') + + #################################################### + + train_data_sample_keys = [] + + if not args['archive']: + print(f'register train data (from server) on datamanager {org_1} (will take datamanager creator as worker)') + data_samples_path = ['./fixtures/chunantes/datasamples/train/0024306', + './fixtures/chunantes/datasamples/train/0024307', + './fixtures/chunantes/datasamples/train/0024308'] + for d in data_samples_path: + try: + shutil.copytree(os.path.join(dir_path, d), + os.path.join(server_path, d)) + except FileExistsError: + pass + data = { + 'paths': [os.path.join(server_path, d) for d in data_samples_path], + 'data_manager_keys': [data_manager_org1_key], + 'test_only': False, + } + train_data_sample_keys = get_or_create(data, org_1, 
'data_samples', local=False) + else: + print(f'register train data on datamanager {org_1} (will take datamanager creator as worker)') + data = { + 'paths': [ + os.path.join(dir_path, './fixtures/chunantes/datasamples/train/0024306'), + os.path.join(dir_path, './fixtures/chunantes/datasamples/train/0024307'), + os.path.join(dir_path, './fixtures/chunantes/datasamples/train/0024308') + ], + 'data_manager_keys': [data_manager_org1_key], + 'test_only': False, + } + train_data_sample_keys = get_or_create(data, org_1, 'data_samples') + + #################################################### + + print(f'create datamanager, test data and objective on {org_0}') + data = { + 'name': 'Simplified ISIC 2018', + 'data_opener': os.path.join(dir_path, './fixtures/owkin/datamanagers/datamanager0/opener.py'), + 'type': 'Images', + 'description': os.path.join(dir_path, './fixtures/owkin/datamanagers/datamanager0/description.md'), + 'permissions': PUBLIC_PERMISSIONS, + } + data_manager_org0_key = get_or_create(data, org_0, 'dataset') + + print(f'create datamanager, test data and objective on {org_1} (should say "already exists")') + data = { + 'name': 'Simplified ISIC 2018', + 'data_opener': os.path.join(dir_path, './fixtures/owkin/datamanagers/datamanager0/opener.py'), + 'type': 'Images', + 'description': os.path.join(dir_path, './fixtures/owkin/datamanagers/datamanager0/description.md'), + 'permissions': PUBLIC_PERMISSIONS, + } + get_or_create(data, org_1, 'dataset') + + #################################################### + + print('register test data') + data = { + 'paths': [ + os.path.join(dir_path, './fixtures/owkin/datasamples/test/0024900'), + os.path.join(dir_path, './fixtures/owkin/datasamples/test/0024901') + ], + 'data_manager_keys': [data_manager_org0_key], + 'test_only': True, + } + test_data_sample_keys = get_or_create(data, org_0, 'data_samples') + + #################################################### + + print('register test data 2') + data = { + 'paths': [ + os.path.join(dir_path, './fixtures/owkin/datasamples/test/0024902'), + os.path.join(dir_path, './fixtures/owkin/datasamples/test/0024903') + ], + 'data_manager_keys': [data_manager_org0_key], + 'test_only': True, + } + get_or_create(data, org_0, 'data_samples') + + #################################################### + + print('register test data 3') + data = { + 'paths': [ + os.path.join(dir_path, './fixtures/owkin/datasamples/test/0024904'), + os.path.join(dir_path, './fixtures/owkin/datasamples/test/0024905') + ], + 'data_manager_keys': [data_manager_org0_key], + 'test_only': True, + } + get_or_create(data, org_0, 'data_samples') + + #################################################### + + with tempfile.TemporaryDirectory() as tmp_dir: + print('register objective') + objective_path = os.path.join( + dir_path, './fixtures/chunantes/objectives/objective0/') + + zip_path = os.path.join(tmp_dir, 'metrics.zip') + zip_folder(objective_path, zip_path) + data = { + 'name': 'Skin Lesion Classification Objective', + 'description': os.path.join(dir_path, './fixtures/chunantes/objectives/objective0/description.md'), + 'metrics_name': 'macro-average recall', + 'metrics': zip_path, + 'permissions': PUBLIC_PERMISSIONS, + 'test_data_sample_keys': test_data_sample_keys, + 'test_data_manager_key': data_manager_org0_key + } + + objective_key = get_or_create(data, org_0, 'objective') + + #################################################### + + print('register objective without data manager and data sample') + objective_path = os.path.join( + dir_path, 
'./fixtures/chunantes/objectives/objective0/') + + zip_path = os.path.join(tmp_dir, 'metrics2.zip') + zip_folder(objective_path, zip_path) + data = { + 'name': 'Skin Lesion Classification Objective', + 'description': os.path.join(dir_path, './fixtures/owkin/objectives/objective0/description.md'), + 'metrics_name': 'macro-average recall', + 'metrics': zip_path, + 'permissions': PUBLIC_PERMISSIONS, + } + + get_or_create(data, org_0, 'objective') + + #################################################### + + # update datamanager + print('update datamanager') + data = { + 'objective_key': objective_key + } + update_datamanager(data_manager_org1_key, data, org_0) + + #################################################### + + # register algo + print('register algo') + data = { + 'name': 'Logistic regression', + 'file': os.path.join(dir_path, './fixtures/chunantes/algos/algo3/algo.tar.gz'), + 'description': os.path.join(dir_path, './fixtures/chunantes/algos/algo3/description.md'), + 'permissions': PUBLIC_PERMISSIONS, + } + algo_key = get_or_create(data, org_2, 'algo') + + #################################################### + + print('register algo 2') + data = { + 'name': 'Neural Network', + 'file': os.path.join(dir_path, './fixtures/chunantes/algos/algo0/algo.tar.gz'), + 'description': os.path.join(dir_path, './fixtures/chunantes/algos/algo0/description.md'), + 'permissions': PUBLIC_PERMISSIONS, + } + algo_key_2 = get_or_create(data, org_1, 'algo') + + #################################################### + + data = { + 'name': 'Random Forest', + 'file': os.path.join(dir_path, './fixtures/chunantes/algos/algo4/algo.tar.gz'), + 'description': os.path.join(dir_path, './fixtures/chunantes/algos/algo4/description.md'), + 'permissions': PUBLIC_PERMISSIONS, + } + algo_key_3 = get_or_create(data, org_1, 'algo') + + #################################################### + + # create traintuple + print('create traintuple') + data = { + 'algo_key': algo_key, + 'objective_key': objective_key, + 'data_manager_key': data_manager_org1_key, + 'train_data_sample_keys': train_data_sample_keys[:2] + # This traintuple should succeed. + # It doesn't have a tag, so it can be used as a test + # of the "non-bundled" display in substra-frontend. 
+ } + traintuple_key = get_or_create(data, org_1, 'traintuple') + + print('create second traintuple') + data = { + 'algo_key': algo_key_2, + 'data_manager_key': data_manager_org1_key, + 'objective_key': objective_key, + 'train_data_sample_keys': train_data_sample_keys[:2], + 'tag': '(should fail) My super tag' + } + + get_or_create(data, org_1, 'traintuple') + + print('create third traintuple') + data = { + 'algo_key': algo_key_3, + 'data_manager_key': data_manager_org1_key, + 'objective_key': objective_key, + 'train_data_sample_keys': train_data_sample_keys[:2], + 'tag': '(should fail) My other tag' + } + + get_or_create(data, org_1, 'traintuple') + + #################################################### + + client.set_profile(org_1) + res = client.get_traintuple(traintuple_key) + print(colored(json.dumps(res, indent=2), 'green')) + + # create testtuple + print('create testtuple') + data = { + 'traintuple_key': traintuple_key, + 'tag': 'substra', + } + + testtuple_key = get_or_create(data, org_1, 'testtuple') + + client.set_profile(org_1) + res_t = client.get_testtuple(testtuple_key) + print(colored(json.dumps(res_t, indent=2), 'yellow')) + + testtuple_status = None + traintuple_status = None + + client.set_profile(org_1) + + while traintuple_status not in ('done', 'failed') or testtuple_status not in ('done', 'failed'): + res = client.get_traintuple(traintuple_key) + res_t = client.get_testtuple(testtuple_key) + if traintuple_status != res['status'] or testtuple_status != res_t['status']: + traintuple_status = res['status'] + testtuple_status = res_t['status'] + print('') + print('-' * 100) + print(colored(json.dumps(res, indent=2), 'green')) + print(colored(json.dumps(res_t, indent=2), 'yellow')) + else: + print('.', end='', flush=True) + + time.sleep(3) + + #################################################### + # Compute plan + + print('create compute plan') + traintuples_data = [ + { + "data_manager_key": data_manager_org1_key, + "train_data_sample_keys": [train_data_sample_keys[0], train_data_sample_keys[2]], + "traintuple_id": "dummy_traintuple_id", + "in_models_ids": [], + "tag": "", + }, + ] + testtuples_data = [ + # { + # "traintuple_id": "dummy_traintuple_id", + # "tag": "", + # } + ] + compute_plan_data = { + "algo_key": algo_key, # logistic regression, org2 + "objective_key": objective_key, # org 0 + "traintuples": traintuples_data, + "testtuples": testtuples_data, + } + # until both chaincode, backend and sdk can handle compute plan collisions, we need to have a + # generic try-except so that this script can run multiple times in a row + try: + client.set_profile(org_1) + res = client.add_compute_plan(compute_plan_data) + print(colored(json.dumps(res, indent=2), 'green')) + except: # noqa: E722 + print(colored('Could not create compute plan', 'red')) + + +if __name__ == '__main__': + try: + do_populate() + except substra.exceptions.HTTPError as e: + print(colored(str(e), 'red')) + exit(1) diff --git a/substrabac/scripts/clean_media.sh b/scripts/clean_media.sh similarity index 100% rename from substrabac/scripts/clean_media.sh rename to scripts/clean_media.sh diff --git a/substrabac/scripts/clean_media_local.sh b/scripts/clean_media_local.sh old mode 100644 new mode 100755 similarity index 69% rename from substrabac/scripts/clean_media_local.sh rename to scripts/clean_media_local.sh index d2015af2d..0732d3cdb --- a/substrabac/scripts/clean_media_local.sh +++ b/scripts/clean_media_local.sh @@ -4,4 +4,4 @@ BASEDIR="$(dirname $(dirname $0))" echo $BASEDIR # clean medias -rm -rf 
${BASEDIR}/medias/* +rm -rf ${BASEDIR}/backend/medias/* diff --git a/substrabac/scripts/load_fixtures.sh b/scripts/load_fixtures.sh similarity index 100% rename from substrabac/scripts/load_fixtures.sh rename to scripts/load_fixtures.sh diff --git a/scripts/populate_db.sh b/scripts/populate_db.sh new file mode 100755 index 000000000..0da71ef41 --- /dev/null +++ b/scripts/populate_db.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +# load dumps +BASEDIR=$(dirname "$0") +psql -U ${USER} -d backend_chunantes < ${BASEDIR}/../fixtures/dump_backend_chunantes.sql +psql -U ${USER} -d backend_owkin < ${BASEDIR}/../fixtures/dump_backend_owkin.sql diff --git a/scripts/recreate_db.sh b/scripts/recreate_db.sh new file mode 100755 index 000000000..966ec9190 --- /dev/null +++ b/scripts/recreate_db.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +dropdb -U ${USER} backend +createdb -U ${USER} -E UTF8 backend +psql -U ${USER} -d backend -c "CREATE USER backend WITH PASSWORD 'backend' CREATEDB CREATEROLE SUPERUSER;" + +dropdb -U ${USER} backend_owkin +createdb -U ${USER} -E UTF8 backend_owkin +psql -U ${USER} -d backend_owkin -c "GRANT ALL PRIVILEGES ON DATABASE backend_owkin to backend;ALTER ROLE backend WITH SUPERUSER CREATEROLE CREATEDB;" + +dropdb -U ${USER} backend_chunantes +createdb -U ${USER} -E UTF8 backend_chunantes +psql -U ${USER} -d backend_chunantes -c "GRANT ALL PRIVILEGES ON DATABASE backend_chunantes to backend;ALTER ROLE backend WITH SUPERUSER CREATEROLE CREATEDB;" + +dropdb -U ${USER} backend_clb +createdb -U ${USER} -E UTF8 backend_clb +psql -U ${USER} -d backend_clb -c "GRANT ALL PRIVILEGES ON DATABASE backend_clb to backend;ALTER ROLE backend WITH SUPERUSER CREATEROLE CREATEDB;" diff --git a/skaffold.yaml b/skaffold.yaml new file mode 100644 index 000000000..4bb71cee6 --- /dev/null +++ b/skaffold.yaml @@ -0,0 +1,118 @@ +apiVersion: skaffold/v1beta13 +kind: Config +build: + artifacts: + - image: substrafoundation/substra-backend + context: . + docker: + dockerfile: docker/substra-backend/Dockerfile + + - image: substrafoundation/celerybeat + context: . + docker: + dockerfile: docker/celerybeat/Dockerfile + + - image: substrafoundation/celeryworker + context: . + docker: + dockerfile: docker/celeryworker/Dockerfile + + - image: substrafoundation/flower + context: . 
+ docker: + dockerfile: docker/flower/Dockerfile + +deploy: + helm: + releases: + - name: substra-backend-peer-1 + chartPath: charts/substra-backend + namespace: peer-1 + imageStrategy: + helm: {} + values: + backend.image: substrafoundation/substra-backend + celerybeat.image: substrafoundation/celerybeat + celeryworker.image: substrafoundation/celeryworker + flower.image: substrafoundation/flower + overrides: + secrets: + fabricConfigmap: network-peer-1-hlf-k8s-fabric + backend: + settings: dev + defaultDomain: http://substra-backend.node-1.com + ingress: + enabled: true + hosts: + - { host: substra-backend.node-1.com, paths: ["/"] } + annotations: + kubernetes.io/ingress.class: nginx + nginx.ingress.kubernetes.io/client-body-buffer-size: 100m + nginx.ingress.kubernetes.io/proxy-body-size: 100m + organization: + name: MyPeer1 + peer: + host: network-peer-1.peer-1 + port: 7051 + mspID: MyPeer1MSP + orderer: + host: network-orderer.orderer + port: 7050 + name: MyOrderer + persistence: + hostPath: /tmp/peer-1 + incomingNodes: + - { name: MyPeer1MSP, secret: selfSecret1 } + - { name: MyPeer2MSP, secret: nodeSecret1w2 } + outgoingNodes: + - { name: MyPeer1MSP, secret: selfSecret1 } + - { name: MyPeer2MSP, secret: nodeSecret2w1 } + users: + - name: "node-1" + secret: "p@$swr0d44" + + - name: substra-backend-peer-2 + chartPath: charts/substra-backend + namespace: peer-2 + imageStrategy: + helm: {} + values: + backend.image: substrafoundation/substra-backend + celerybeat.image: substrafoundation/celerybeat + celeryworker.image: substrafoundation/celeryworker + flower.image: substrafoundation/flower + overrides: + secrets: + fabricConfigmap: network-peer-2-hlf-k8s-fabric + backend: + settings: dev + defaultDomain: http://substra-backend.node-2.com + ingress: + enabled: true + hosts: + - { host: substra-backend.node-2.com, paths: ["/"] } + annotations: + kubernetes.io/ingress.class: nginx + nginx.ingress.kubernetes.io/client-body-buffer-size: 100m + nginx.ingress.kubernetes.io/proxy-body-size: 100m + organization: + name: MyPeer2 + peer: + host: network-peer-2.peer-2 + port: 7051 + mspID: MyPeer2MSP + orderer: + host: network-orderer.orderer + port: 7050 + name: MyOrderer + persistence: + hostPath: /tmp/peer-2 + incomingNodes: + - { name: MyPeer1MSP, secret: nodeSecret2w1 } + - { name: MyPeer2MSP, secret: selfSecret2 } + outgoingNodes: + - { name: MyPeer1MSP, secret: nodeSecret1w2 } + - { name: MyPeer2MSP, secret: selfSecret2 } + users: + - name: "node-2" + secret: "p@$swr0d45" diff --git a/substrabac/base_metrics/Dockerfile b/substrabac/base_metrics/Dockerfile deleted file mode 100644 index 679578a93..000000000 --- a/substrabac/base_metrics/Dockerfile +++ /dev/null @@ -1,9 +0,0 @@ -FROM eu.gcr.io/substra-208412/substratools - -RUN mkdir -p /sandbox -RUN mkdir -p /sandbox/opener -RUN mkdir -p /sandbox/metrics -WORKDIR /sandbox - -ENTRYPOINT ["python3"] -CMD ["-c", "import substratools as tools; tools.metrics.execute()"] diff --git a/substrabac/fake_data_sample/Dockerfile b/substrabac/fake_data_sample/Dockerfile deleted file mode 100644 index a47637b76..000000000 --- a/substrabac/fake_data_sample/Dockerfile +++ /dev/null @@ -1,13 +0,0 @@ -FROM eu.gcr.io/substra-208412/substratools - -RUN apt-get update; apt-get install -y build-essential libssl-dev python3 python3-dev python3-pip -RUN pip3 install --upgrade pip -RUN pip3 install pillow numpy sklearn pandas - -RUN mkdir -p /sandbox/metrics - -WORKDIR /sandbox - -ADD ./open_data_sample.py . 
- -ENTRYPOINT ["python3", "open_data_sample.py"] diff --git a/substrabac/fake_data_sample/open_data_sample.py b/substrabac/fake_data_sample/open_data_sample.py deleted file mode 100644 index 3b15172b9..000000000 --- a/substrabac/fake_data_sample/open_data_sample.py +++ /dev/null @@ -1,7 +0,0 @@ -import substratools as tools - - -if __name__ == "__main__": - opener = tools.opener.load_from_module() - opener.get_X() - opener.get_y() diff --git a/substrabac/fake_metrics/Dockerfile b/substrabac/fake_metrics/Dockerfile deleted file mode 100644 index 61a90bf85..000000000 --- a/substrabac/fake_metrics/Dockerfile +++ /dev/null @@ -1,9 +0,0 @@ -FROM eu.gcr.io/substra-208412/substratools - -RUN mkdir -p /sandbox -RUN mkdir -p /sandbox/opener -RUN mkdir -p /sandbox/metrics -WORKDIR /sandbox - -ENTRYPOINT ["python3"] -CMD ["-c", "import substratools as tools; tools.metrics.execute(dry_run=True)"] diff --git a/substrabac/fixtures/chunantes/algos/algo0/description.md~ b/substrabac/fixtures/chunantes/algos/algo0/description.md~ deleted file mode 100644 index dd420e7e8..000000000 --- a/substrabac/fixtures/chunantes/algos/algo0/description.md~ +++ /dev/null @@ -1,5 +0,0 @@ -# My top unefficient algo 2 - -Set of one-vs-all logistic regression using sklearn (SGD classfier with loss=log) - -Performance are very bad, since the metrics is the macro average recall score and elements of two classes are very bad predicted... diff --git a/substrabac/fixtures/chunantes/algos/algo3/algo.tar.gz b/substrabac/fixtures/chunantes/algos/algo3/algo.tar.gz deleted file mode 100644 index 8805ccf0b..000000000 Binary files a/substrabac/fixtures/chunantes/algos/algo3/algo.tar.gz and /dev/null differ diff --git a/substrabac/fixtures/chunantes/algos/algo4/description.md~ b/substrabac/fixtures/chunantes/algos/algo4/description.md~ deleted file mode 100644 index db4902dd5..000000000 --- a/substrabac/fixtures/chunantes/algos/algo4/description.md~ +++ /dev/null @@ -1,5 +0,0 @@ -# My top unefficient algo 3 - -Set of one-vs-all logistic regression using sklearn (SGD classfier with loss=log) - -Performance are very bad, since the metrics is the macro average recall score and elements of two classes are very bad predicted... 
diff --git a/substrabac/fixtures/isic_2018.py b/substrabac/fixtures/isic_2018.py deleted file mode 100644 index e76b84bf3..000000000 --- a/substrabac/fixtures/isic_2018.py +++ /dev/null @@ -1,108 +0,0 @@ -import os -import json -from subprocess import PIPE, Popen as popen -import time - -from django.conf import settings - -dir_path = os.path.dirname(os.path.realpath(__file__)) - -# Use substra shell SDK -try: - popen(['substra'], stdout=PIPE).communicate()[0] -except: - print('Substrabac SDK is not installed, please run pip install git+https://github.com/SubstraFoundation/substrabacSDK.git@master') -else: - print('Init config in /tmp/.substrabac for owkin and chunantes') - username = "owkestra" - password = "owkestrapwd" - auth = [] - if username is not None and password is not None: - auth = [username, password] - res = popen(['substra', 'config', 'https://substra.owkin.com:9000', '0.0', '--profile=owkin', '--config=/tmp/.substrabac'] + auth, stdout=PIPE).communicate()[0] - - print('create data manager with owkin org') - # create data manager with owkin org - data = json.dumps({ - "name": "ISIC 2018", - "data_opener": "/Users/kelvin/Substra/substra-challenge/skin-lesion-classification/dataset/isic2018/opener.py", - "type": "Images", - "description": "/Users/kelvin/Substra/substra-challenge/skin-lesion-classification/dataset/isic2018/description.md", - "permissions": "all", - "challenge_keys": [] - }) - - res = popen(['substra', 'add', 'datamanager', '--profile=owkin', '--config=/tmp/.substrabac', data], - stdout=PIPE).communicate()[0] - res_data = json.loads(res.decode('utf-8')) - datamanager_key = res_data['pkhash'] - print(json.dumps(res_data, indent=2)) - - # Register Data on substrabac docker - # python3 manage.py bulkcreatedata /substra/datasets/isic2018/train_data.json; python3 manage.py bulkcreatedata /substra/datasets/isic2018/test_data.json - - print('You have to register data manually') - input("When it is done, press Enter to continue...") - - # register objective - print('register objective') - data = json.dumps({ - "name": "Skin Lesion Classification Objective", - "description": "/Users/kelvin/Substra/substra-challenge/skin-lesion-classification/description.md", - "metrics_name": "macro-average recall", - "metrics": "/Users/kelvin/Substra/substra-challenge/skin-lesion-classification/metrics.py", - "permissions": "all", - "test_data_sample_keys": ["039eecf8279c570022f000984d91e175ca8efbf858f11b8bffc88d91ccb51096"] - }) - - res = popen(['substra', 'add', 'objective', '--profile=owkin', '--config=/tmp/.substrabac', data], - stdout=PIPE).communicate()[0] - res_data = json.loads(res.decode('utf-8')) - objective_key = res_data['pkhash'] - print(json.dumps(res_data, indent=2)) - - # ############################ - - # register algo - print('register algo') - data = json.dumps({ - "name": "CNN Classifier GPU Updated", - "file": "/Users/kelvin/Substra/substra-challenge/skin-lesion-classification/algo/algo.tar.gz", - "description": "/Users/kelvin/Substra/substra-challenge/skin-lesion-classification/algo/description.md", - "objective_key": objective_key, - "permissions": "all", - }) - - res = popen(['substra', 'add', 'algo', '--profile=owkin', '--config=/tmp/.substrabac', data], - stdout=PIPE).communicate()[0] - res_data = json.loads(res.decode('utf-8')) - algo_key = res_data['pkhash'] - print(json.dumps(res_data, indent=2)) - - # #################################### - - # create traintuple - print('create traintuple') - data = json.dumps({ - "algo_key": algo_key, - "model_key": "", - 
"train_data_keys": ["33d577a1dbbf95c9cfccc4853ad7ca369b535243053f84a206308ad46e89aa59"] - }) - - res = popen(['substra', 'add', 'traintuple', '--profile=owkin', '--config=/tmp/.substrabac', data], - stdout=PIPE).communicate()[0] - res_data = json.loads(res.decode('utf-8')) - trainuple_key = res_data['pkhash'] - print(json.dumps(res_data, indent=2)) - - # Check traintuple - res = popen(['substra', 'get', 'traintuple', trainuple_key, '--profile=owkin', '--config=/tmp/.substrabac'], - stdout=PIPE).communicate()[0] - res = json.loads(res.decode('utf-8')) - print(json.dumps(res, indent=2)) - while res['status'] != 'done': - res = popen(['substra', 'get', 'traintuple', trainuple_key, '--profile=owkin', '--config=/tmp/.substrabac'], - stdout=PIPE).communicate()[0] - res = json.loads(res.decode('utf-8')) - print(json.dumps(res, indent=2)) - time.sleep(3) diff --git a/substrabac/libs/BasicAuthMiddleware.py b/substrabac/libs/BasicAuthMiddleware.py deleted file mode 100644 index 198e2665b..000000000 --- a/substrabac/libs/BasicAuthMiddleware.py +++ /dev/null @@ -1,50 +0,0 @@ -from django.core.exceptions import MiddlewareNotUsed -from django.http import HttpResponse -from django.conf import settings - -import base64 - - -html_template = """ - - Auth required - -

-<h1>Authorization Required</h1>
- - -""" - - -class BasicAuthMiddleware: - def unauthed(self): - response = HttpResponse(html_template, content_type="text/html") - response['WWW-Authenticate'] = 'Basic realm="Administrator area"' - response.status_code = 401 - return response - - def __init__(self, get_response): - if settings.DEBUG: - raise MiddlewareNotUsed - - self.get_response = get_response - # One-time configuration and initialization. - - def __call__(self, request): - username = getattr(settings, 'BASICAUTH_USERNAME', None) - password = getattr(settings, 'BASICAUTH_PASSWORD', None) - - if username not in (None, '') and password not in (None, ''): - if request.method != 'OPTIONS': - if 'HTTP_AUTHORIZATION' not in request.META: - return self.unauthed() - else: - authentication = request.META['HTTP_AUTHORIZATION'] - (authmeth, auth) = authentication.split(' ', 1) - if 'basic' != authmeth.lower(): - return self.unauthed() - auth = base64.b64decode(auth.strip()).decode('utf-8') - username, password = auth.split(':', 1) - if username != settings.BASICAUTH_USERNAME or password != settings.BASICAUTH_PASSWORD: - return self.unauthed() - - return self.get_response(request) diff --git a/substrabac/libs/__init__.py b/substrabac/libs/__init__.py deleted file mode 100644 index ca5d1ad2d..000000000 --- a/substrabac/libs/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__author__ = 'guillaume' diff --git a/substrabac/populate.py b/substrabac/populate.py deleted file mode 100644 index 20cb01706..000000000 --- a/substrabac/populate.py +++ /dev/null @@ -1,337 +0,0 @@ -import argparse -import os -import json -import shutil -import time - -import substra_sdk_py as substra - -from termcolor import colored - -dir_path = os.path.dirname(os.path.realpath(__file__)) -server_path = '/substra/servermedias' - -client = substra.Client() - - -def setup_config(): - print('Init config in /tmp/.substrabac for owkin and chunantes') - client.create_config('owkin', 'http://owkin.substrabac:8000', '0.0') - client.create_config('chunantes', 'http://chunantes.substrabac:8001', '0.0') - client.create_config('clb', 'http://clb.substrabac:8002', '0.0') - - -def get_or_create(data, profile, asset, dryrun=False, register=False): - - client.set_config(profile) - - method = client.add if not register else client.register - - if dryrun: - print('dryrun') - try: - r = method(asset, data, dryrun=True) - except substra.exceptions.AlreadyExists as e: - r = e.response.json() - print(colored(json.dumps(r, indent=2), 'cyan')) - else: - print(colored(json.dumps(r, indent=2), 'magenta')) - - print('real') - try: - r = method(asset, data) - - except substra.exceptions.AlreadyExists as e: - r = e.response.json() - print(colored(json.dumps(r, indent=2), 'cyan')) - key_or_keys = e.pkhash - - else: - print(colored(json.dumps(r, indent=2), 'green')) - key_or_keys = [x['pkhash'] for x in r] if isinstance(r, list) else r['pkhash'] - - return key_or_keys - - -def update_datamanager(data_manager_key, data, profile): - client.set_config(profile) - try: - r = client.update('data_manager', data_manager_key, data) - - except substra.exceptions.AlreadyExists as e: - r = e.response.json() - print(colored(json.dumps(r, indent=2), 'cyan')) - - except substra.exceptions.InvalidRequest as e: - # FIXME if the data manager is already associated with the objective - # backend answer with a 400 and a raw error coming from the - # ledger. - # this case will be handled soon, with the fabric SDK. 
- print(colored(str(e), 'red')) - - else: - print(colored(json.dumps(r, indent=2), 'green')) - - -def do_populate(): - setup_config() - - parser = argparse.ArgumentParser() - group = parser.add_mutually_exclusive_group() - group.add_argument('-o', '--one-org', action='store_const', dest='nb_org', const=1, - help='Launch populate with one org') - group.add_argument('-tw', '--two-orgs', action='store_const', dest='nb_org', const=2, - help='Launch populate with two orgs') - group.add_argument('-th', '--three-orgs', action='store_const', dest='nb_org', const=3, - help='Launch populate with three orgs') - parser.set_defaults(nb_org=2) - args = vars(parser.parse_args()) - - if args['nb_org'] == 1: - org_0 = org_1 = org_2 = 'owkin' - elif args['nb_org'] == 2: - org_0 = org_2 = 'owkin' - org_1 = 'chunantes' - elif args['nb_org'] == 3: - org_0 = 'owkin' - org_1 = 'chunantes' - org_2 = 'clb' - else: - raise Exception(f"Number of orgs {args['nb_org']} not in [1, 2, 3]") - - print(f'will create datamanager with {org_1}') - # create datamanager with org1 - data = { - 'name': 'ISIC 2018', - 'data_opener': os.path.join(dir_path, './fixtures/chunantes/datamanagers/datamanager0/opener.py'), - 'type': 'Images', - 'description': os.path.join(dir_path, './fixtures/chunantes/datamanagers/datamanager0/description.md'), - 'permissions': 'all', - } - data_manager_org1_key = get_or_create(data, org_1, 'data_manager', dryrun=True) - - #################################################### - - train_data_sample_keys = [] - print(f'register train data (from server) on datamanager {org_1} (will take datamanager creator as worker)') - data_samples_path = ['./fixtures/chunantes/datasamples/train/0024306', - './fixtures/chunantes/datasamples/train/0024307'] - for d in data_samples_path: - try: - shutil.copytree(os.path.join(dir_path, d), - os.path.join(server_path, d)) - except FileExistsError: - pass - data = { - 'paths': [os.path.join(server_path, d) for d in data_samples_path], - 'data_manager_keys': [data_manager_org1_key], - 'test_only': False, - } - train_data_sample_keys = get_or_create(data, org_1, 'data_sample', dryrun=True, register=True) - - #################################################### - - print(f'create datamanager, test data and objective on {org_0}') - data = { - 'name': 'Simplified ISIC 2018', - 'data_opener': os.path.join(dir_path, './fixtures/owkin/datamanagers/datamanager0/opener.py'), - 'type': 'Images', - 'description': os.path.join(dir_path, './fixtures/owkin/datamanagers/datamanager0/description.md'), - 'permissions': 'all' - } - data_manager_org0_key = get_or_create(data, org_0, 'data_manager') - - #################################################### - - print('register test data') - data = { - 'paths': [ - os.path.join(dir_path, './fixtures/owkin/datasamples/test/0024900.zip'), - os.path.join(dir_path, './fixtures/owkin/datasamples/test/0024901.zip') - ], - 'data_manager_keys': [data_manager_org0_key], - 'test_only': True, - } - test_data_sample_keys = get_or_create(data, org_0, 'data_sample') - - #################################################### - - print('register test data 2') - data = { - 'paths': [ - os.path.join(dir_path, './fixtures/owkin/datasamples/test/0024902.zip'), - os.path.join(dir_path, './fixtures/owkin/datasamples/test/0024903.zip') - ], - 'data_manager_keys': [data_manager_org0_key], - 'test_only': True, - } - get_or_create(data, org_0, 'data_sample') - - #################################################### - - print('register test data 3') - data = { - 'paths': [ - 
os.path.join(dir_path, './fixtures/owkin/datasamples/test/0024904.zip'), - os.path.join(dir_path, './fixtures/owkin/datasamples/test/0024905.zip') - ], - 'data_manager_keys': [data_manager_org0_key], - 'test_only': True, - } - get_or_create(data, org_0, 'data_sample') - - #################################################### - - print('register objective') - data = { - 'name': 'Skin Lesion Classification Objective', - 'description': os.path.join(dir_path, './fixtures/chunantes/objectives/objective0/description.md'), - 'metrics_name': 'macro-average recall', - 'metrics': os.path.join(dir_path, './fixtures/chunantes/objectives/objective0/metrics.py'), - 'permissions': 'all', - 'test_data_sample_keys': test_data_sample_keys, - 'test_data_manager_key': data_manager_org0_key - } - - objective_key = get_or_create(data, org_0, 'objective', dryrun=True) - - #################################################### - - print('register objective without data manager and data sample') - data = { - 'name': 'Skin Lesion Classification Objective', - 'description': os.path.join(dir_path, './fixtures/owkin/objectives/objective0/description.md'), - 'metrics_name': 'macro-average recall', - 'metrics': os.path.join(dir_path, './fixtures/owkin/objectives/objective0/metrics.py'), - 'permissions': 'all' - } - - get_or_create(data, org_0, 'objective', dryrun=True) - - #################################################### - - # update datamanager - print('update datamanager') - data = { - 'objective_key': objective_key - } - update_datamanager(data_manager_org1_key, data, org_0) - - #################################################### - - # register algo - print('register algo') - data = { - 'name': 'Logistic regression', - 'file': os.path.join(dir_path, './fixtures/chunantes/algos/algo3/algo.tar.gz'), - 'description': os.path.join(dir_path, './fixtures/chunantes/algos/algo3/description.md'), - 'permissions': 'all', - } - algo_key = get_or_create(data, org_2, 'algo') - - #################################################### - - print('register algo 2') - data = { - 'name': 'Neural Network', - 'file': os.path.join(dir_path, './fixtures/chunantes/algos/algo0/algo.tar.gz'), - 'description': os.path.join(dir_path, './fixtures/chunantes/algos/algo0/description.md'), - 'permissions': 'all', - } - algo_key_2 = get_or_create(data, org_1, 'algo') - - #################################################### - - data = { - 'name': 'Random Forest', - 'file': os.path.join(dir_path, './fixtures/chunantes/algos/algo4/algo.tar.gz'), - 'description': os.path.join(dir_path, './fixtures/chunantes/algos/algo4/description.md'), - 'permissions': 'all', - } - algo_key_3 = get_or_create(data, org_1, 'algo') - - #################################################### - - # create traintuple - print('create traintuple') - data = { - 'algo_key': algo_key, - 'objective_key': objective_key, - 'data_manager_key': data_manager_org1_key, - 'train_data_sample_keys': train_data_sample_keys, - 'tag': 'substra' - } - traintuple_key = get_or_create(data, org_1, 'traintuple') - - print('create second traintuple') - data = { - 'algo_key': algo_key_2, - 'data_manager_key': data_manager_org1_key, - 'objective_key': objective_key, - 'train_data_sample_keys': train_data_sample_keys, - 'tag': 'My super tag' - } - - get_or_create(data, org_1, 'traintuple') - - print('create third traintuple') - data = { - 'algo_key': algo_key_3, - 'data_manager_key': data_manager_org1_key, - 'objective_key': objective_key, - 'train_data_sample_keys': train_data_sample_keys, - } - - 
get_or_create(data, org_1, 'traintuple') - - #################################################### - - client.set_config(org_1) - res = client.get('traintuple', traintuple_key) - print(colored(json.dumps(res, indent=2), 'green')) - - # create testtuple - print('create testtuple') - data = { - 'traintuple_key': traintuple_key - } - - testtuple_key = get_or_create(data, org_1, 'testtuple') - - client.set_config(org_1) - res_t = client.get('testtuple', testtuple_key) - print(colored(json.dumps(res_t, indent=2), 'yellow')) - - testtuple_status = None - traintuple_status = None - - client.set_config(org_1) - - while traintuple_status not in ('done', 'failed') or testtuple_status not in ('done', 'failed'): - res = client.get('traintuple', traintuple_key) - res_t = client.get('testtuple', testtuple_key) - if traintuple_status != res['status'] or testtuple_status != res_t['status']: - traintuple_status = res['status'] - testtuple_status = res_t['status'] - print('') - print('-' * 100) - print(colored(json.dumps(res, indent=2), 'green')) - print(colored(json.dumps(res_t, indent=2), 'yellow')) - else: - print('.', end='', flush=True) - - time.sleep(3) - - -if __name__ == '__main__': - try: - do_populate() - except substra.exceptions.HTTPError as e: - try: - error = e.response.json() - except Exception: - error_message = e.response.text - else: - error_message = json.dumps(error, indent=2) - print(colored(str(e), 'red')) - print(colored(error_message, 'red')) diff --git a/substrabac/scripts/generate_assets.py b/substrabac/scripts/generate_assets.py deleted file mode 100644 index ea0447a02..000000000 --- a/substrabac/scripts/generate_assets.py +++ /dev/null @@ -1,38 +0,0 @@ -import os -import json -from substra_sdk_py import Client - - -dir_path = os.path.dirname(os.path.realpath(__file__)) - - -def main(): - - client = Client() - client.create_config('owkin', 'http://owkin.substrabac:8000', '0.0') - - client.set_config('owkin') - - assets = {} - assets['objective'] = json.dumps(client.list('objective'), indent=4) - assets['datamanager'] = json.dumps(client.list('data_manager'), indent=4) - assets['algo'] = json.dumps(client.list('algo'), indent=4) - assets['traintuple'] = json.dumps(client.list('traintuple'), indent=4) - assets['testtuple'] = json.dumps(client.list('testtuple'), indent=4) - - assets['model'] = json.dumps([res for res in client.list('model') - if ('traintuple' in res and 'testtuple' in res)], indent=4) - - with open(os.path.join(dir_path, '../substrapp/tests/assets.py'), 'w') as f: - for k, v in assets.items(): - v = v.replace('owkin.substrabac:8000', 'testserver') - v = v.replace('chunantes.substrabac:8001', 'testserver') - v = v.replace('true', 'True') - v = v.replace('false', 'False') - v = v.replace('null', 'None') - f.write(f'{k} = {v}') - f.write('\n\n') - - -if __name__ == '__main__': - main() diff --git a/substrabac/scripts/populate_db.sh b/substrabac/scripts/populate_db.sh deleted file mode 100755 index e0560c955..000000000 --- a/substrabac/scripts/populate_db.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash - -# load dumps -BASEDIR=$(dirname "$0") -psql -U ${USER} -d substrabac_chunantes < ${BASEDIR}/../fixtures/dump_substrabac_chunantes.sql -psql -U ${USER} -d substrabac_owkin < ${BASEDIR}/../fixtures/dump_substrabac_owkin.sql diff --git a/substrabac/scripts/recreate_db.sh b/substrabac/scripts/recreate_db.sh deleted file mode 100755 index 3cf57c3b8..000000000 --- a/substrabac/scripts/recreate_db.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash - -dropdb -U ${USER} substrabac 
-createdb -U ${USER} -E UTF8 substrabac -psql -U ${USER} -d substrabac -c "CREATE USER substrabac WITH PASSWORD 'substrabac' CREATEDB CREATEROLE SUPERUSER;" - -dropdb -U ${USER} substrabac_owkin -createdb -U ${USER} -E UTF8 substrabac_owkin -psql -U ${USER} -d substrabac_owkin -c "GRANT ALL PRIVILEGES ON DATABASE substrabac_owkin to substrabac;ALTER ROLE substrabac WITH SUPERUSER CREATEROLE CREATEDB;" - -dropdb -U ${USER} substrabac_chunantes -createdb -U ${USER} -E UTF8 substrabac_chunantes -psql -U ${USER} -d substrabac_chunantes -c "GRANT ALL PRIVILEGES ON DATABASE substrabac_chunantes to substrabac;ALTER ROLE substrabac WITH SUPERUSER CREATEROLE CREATEDB;" diff --git a/substrabac/substrabac/celery.py b/substrabac/substrabac/celery.py deleted file mode 100644 index faf074883..000000000 --- a/substrabac/substrabac/celery.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import absolute_import, unicode_literals -import os -from celery import Celery - - -# set the default Django settings module for the 'celery' program. -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'substrabac.settings.prod') - -app = Celery('substrabac') - -# Using a string here means the worker doesn't have to serialize -# the configuration object to child processes. -# - namespace='CELERY' means all celery-related configuration keys -# should have a `CELERY_` prefix. -app.config_from_object('django.conf:settings', namespace='CELERY') - -# Load task modules from all registered Django app configs. -app.autodiscover_tasks() - - -@app.task(bind=True) -def debug_task(self): - print('Request: {0!r}'.format(self.request)) - - -@app.on_after_configure.connect -def setup_periodic_tasks(sender, **kwargs): - from substrapp.tasks import prepareTrainingTask, prepareTestingTask - - period = 10 - sender.add_periodic_task(period, prepareTrainingTask.s(), queue='scheduler', - name='query Traintuples to prepare train task on todo traintuples') - sender.add_periodic_task(period, prepareTestingTask.s(), queue='scheduler', - name='query Testuples to prepare test task on todo testuples') diff --git a/substrabac/substrabac/settings/deps/restframework.py b/substrabac/substrabac/settings/deps/restframework.py deleted file mode 100644 index 44b021d45..000000000 --- a/substrabac/substrabac/settings/deps/restframework.py +++ /dev/null @@ -1,15 +0,0 @@ -REST_FRAMEWORK = { - 'TEST_REQUEST_DEFAULT_FORMAT': 'json', - 'DEFAULT_RENDERER_CLASSES': ( - 'rest_framework.renderers.JSONRenderer', - #'rest_framework.renderers.AdminRenderer', - 'rest_framework.renderers.BrowsableAPIRenderer', - ), - 'DEFAULT_AUTHENTICATION_CLASSES': ( - 'rest_framework.authentication.SessionAuthentication', - ), - 'UNICODE_JSON': False, - 'DEFAULT_VERSIONING_CLASS': 'libs.versioning.AcceptHeaderVersioningRequired', - 'ALLOWED_VERSIONS': ('0.0',), - 'DEFAULT_VERSION': '0.0', -} diff --git a/substrabac/substrapp/exception_handler.py b/substrabac/substrapp/exception_handler.py deleted file mode 100644 index d1d59c9b1..000000000 --- a/substrabac/substrapp/exception_handler.py +++ /dev/null @@ -1,68 +0,0 @@ -import os -import uuid -import docker.errors -import traceback -import json -import re - -LANGUAGES = { - 'ShellScript': '00', - 'Python': '01' -} - -SERVICES = { - 'System': '00', - 'Docker': '01' -} - -EXCEPTION_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'exceptions.json') - -EXCEPTIONS_UUID_LENGTH = 7 - -if os.path.exists(EXCEPTION_PATH): - try: - EXCEPTIONS_MAP = json.load(open(EXCEPTION_PATH)) - except: - # The json may be corrupted - EXCEPTIONS_MAP 
= dict() -else: - EXCEPTIONS_MAP = dict() - - -def get_exception_codes_from_docker_trace(): - container_code = EXCEPTIONS_MAP[docker.errors.ContainerError.__name__] - - # Get last line of the docker traceback which contains the traceback inside the container - docker_traceback = traceback.format_exc().splitlines()[-1].encode('utf_8').decode('unicode_escape') - docker_traceback = re.split(':| |\n', docker_traceback) - - exception_codes = [code for exception, code in EXCEPTIONS_MAP.items() - if exception in docker_traceback and code != container_code] - - return exception_codes - - -def get_exception_code(exception_type): - - service_code = SERVICES['System'] - exception_code = EXCEPTIONS_MAP.get(exception_type.__name__, '0000') # '0000' is default exception code - - # Exception inside a docker container - if docker.errors.ContainerError.__name__ in EXCEPTIONS_MAP and \ - exception_code == EXCEPTIONS_MAP[docker.errors.ContainerError.__name__]: - - exception_codes = get_exception_codes_from_docker_trace() - - if len(exception_codes) > 0: - # Take the first code in the list (may have more if multiple exceptions are raised) - service_code = SERVICES['Docker'] - exception_code = exception_codes.pop() - - return exception_code, service_code - - -def compute_error_code(exception): - exception_uuid = str(uuid.uuid4())[:EXCEPTIONS_UUID_LENGTH] - exception_code, service_code = get_exception_code(exception.__class__) - error_code = f'[{service_code}-{LANGUAGES["Python"]}-{exception_code}-{exception_uuid}]' - return error_code diff --git a/substrabac/substrapp/fixtures/model.py b/substrabac/substrapp/fixtures/model.py deleted file mode 100644 index 293df3c97..000000000 --- a/substrabac/substrapp/fixtures/model.py +++ /dev/null @@ -1,202 +0,0 @@ -fake_models = [ - { - 'algo': { - 'hash': '76fe474d441b03e8416ab37b4950286014fb329e9317126e144342dd0e2ec895', - 'name': 'Neural Network', - 'storageAddress': 'http://chunantes.substrabac:8001/algo/76fe474d441b03e8416ab37b4950286014fb329e9317126e144342dd0e2ec895/file/', - }, - 'objective': { - 'hash': '3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71', - 'metrics': { - 'hash': '750f622262854341bd44f55c1018949e9c119606ef5068bd7d137040a482a756', - 'storageAddress': 'http://chunantes.substrabac:8001/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/', - }, - }, - 'creator': 'a3119c79a173581425cbe6e06c3034ec396ee805b60d9a34feaa3048beb0e4a9', - 'outModel': { - 'hash': '30060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568', - 'storageAddress': 'http://chunantes.substrabac:8001/model/30060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568/file/', - }, - 'key': '1bb5c8f42315914909c764545ea44e32b04c773468c439c9eb506176670ee6b8', - 'log': 'no error, ah ah ahstill no error, suprah ah ah', - 'permissions': 'all', - 'inModel': { - 'hash': '20060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568', - 'storageAddress': 'http://chunantes.substrabac:8001/model/20060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568/file/', - }, - 'status': 'done', - 'testDataSample': { - 'keys': ['4b5152871b181d10ee774c10458c064c70710f4ba35938f10c0b7aa51f7dc010'], - 'openerHash': 'a8b7c235abb9a93742e336bd76ff7cd8ecc49f612e5cf6ea506dc10f4fd6b6f0', - 'perf': 0.20, - 'worker': '2d76419f4231cf67bdc53f569201322a4822dff152351fb468db013d484fc762', - }, - 'trainDataSample': { - 'keys': ['62fb3263208d62c7235a046ee1d80e25512fe782254b730a9e566276b8c0ef3a', - 
'42303efa663015e729159833a12ffb510ff92a6e386b8152f90f6fb14ddc94c9'], - 'openerHash': '615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7', - 'perf': 0.50, - 'worker': 'a3119c79a173581425cbe6e06c3034ec396ee805b60d9a34feaa3048beb0e4a9', - }, - }, - { - 'algo': { - 'hash': '76fe474d441b03e8416ab37b4950286014fb329e9317126e144342dd0e2ec895', - 'name': 'Neural Network', - 'storageAddress': 'http://chunantes.substrabac:8001/algo/76fe474d441b03e8416ab37b4950286014fb329e9317126e144342dd0e2ec895/file/', - }, - 'objective': { - 'hash': '3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71', - 'metrics': { - 'hash': '750f622262854341bd44f55c1018949e9c119606ef5068bd7d137040a482a756', - 'storageAddress': 'http://chunantes.substrabac:8001/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/', - }, - }, - 'creator': 'a3119c79a173581425cbe6e06c3034ec396ee805b60d9a34feaa3048beb0e4a9', - 'outModel': { - 'hash': '40060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568', - 'storageAddress': 'http://chunantes.substrabac:8001/model/40060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568/file/', - }, - 'key': '2bb5c8f42315914909c764545ea44e32b04c773468c439c9eb506176670ee6b8', - 'log': 'no error, ah ah ahstill no error, suprah ah ah', - 'permissions': 'all', - 'inModel': { - 'hash': '30060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568', - 'storageAddress': 'http://chunantes.substrabac:8001/model/30060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568/file/', - }, - 'status': 'done', - 'testDataSample': { - 'keys': ['4b5152871b181d10ee774c10458c064c70710f4ba35938f10c0b7aa51f7dc010'], - 'openerHash': 'a8b7c235abb9a93742e336bd76ff7cd8ecc49f612e5cf6ea506dc10f4fd6b6f0', - 'perf': 0.35, - 'worker': 'a3119c79a173581425cbe6e06c3034ec396ee805b60d9a34feaa3048beb0e4a9', - }, - 'trainDataSample': { - 'keys': ['62fb3263208d62c7235a046ee1d80e25512fe782254b730a9e566276b8c0ef3a', - '42303efa663015e729159833a12ffb510ff92a6e386b8152f90f6fb14ddc94c9'], - 'openerHash': '615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7', - 'perf': 0.70, - 'worker': 'a3119c79a173581425cbe6e06c3034ec396ee805b60d9a34feaa3048beb0e4a9', - }, - }, - { - 'algo': { - 'hash': '76fe474d441b03e8416ab37b4950286014fb329e9317126e144342dd0e2ec895', - 'name': 'Neural Network', - 'storageAddress': 'http://chunantes.substrabac:8001/algo/76fe474d441b03e8416ab37b4950286014fb329e9317126e144342dd0e2ec895/file/', - }, - 'objective': { - 'hash': '3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71', - 'metrics': { - 'hash': '750f622262854341bd44f55c1018949e9c119606ef5068bd7d137040a482a756', - 'storageAddress': 'http://chunantes.substrabac:8001/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/', - }, - }, - 'creator': 'a3119c79a173581425cbe6e06c3034ec396ee805b60d9a34feaa3048beb0e4a9', - 'outModel': { - 'hash': '50060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568', - 'storageAddress': 'http://chunantes.substrabac:8001/model/50060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568/file/', - }, - 'key': '3bb5c8f42315914909c764545ea44e32b04c773468c439c9eb506176670ee6b8', - 'log': 'no error, ah ah ahstill no error, suprah ah ah', - 'permissions': 'all', - 'inModel': { - 'hash': '40060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568', - 'storageAddress': 'http://chunantes.substrabac:8001/model/40060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568/file/', - }, - 'status': 'done', - 
'testDataSample': { - 'keys': ['4b5152871b181d10ee774c10458c064c70710f4ba35938f10c0b7aa51f7dc010'], - 'openerHash': 'a8b7c235abb9a93742e336bd76ff7cd8ecc49f612e5cf6ea506dc10f4fd6b6f0', - 'perf': 0.79, - 'worker': 'a3119c79a173581425cbe6e06c3034ec396ee805b60d9a34feaa3048beb0e4a9', - }, - 'trainDataSample': { - 'keys': ['62fb3263208d62c7235a046ee1d80e25512fe782254b730a9e566276b8c0ef3a', - '42303efa663015e729159833a12ffb510ff92a6e386b8152f90f6fb14ddc94c9'], - 'openerHash': '615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7', - 'perf': 0.79, - 'worker': 'a3119c79a173581425cbe6e06c3034ec396ee805b60d9a34feaa3048beb0e4a9', - }, - }, - { - 'algo': { - 'hash': '56a0e2f7e046ee948cf2ab38136f7b5ff131d0c538f8d75a97850d6fc06131df', - 'name': 'Random Forest', - 'storageAddress': 'http://chunantes.substrabac:8001/56a0e2f7e046ee948cf2ab38136f7b5ff131d0c538f8d75a97850d6fc06131df/file/', - }, - 'objective': { - 'hash': '3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71', - 'metrics': { - 'hash': '750f622262854341bd44f55c1018949e9c119606ef5068bd7d137040a482a756', - 'storageAddress': 'http://chunantes.substrabac:8001/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/', - }, - }, - 'creator': 'a3119c79a173581425cbe6e06c3034ec396ee805b60d9a34feaa3048beb0e4a9', - 'outModel': { - 'hash': '70060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568', - 'storageAddress': 'http://chunantes.substrabac:8001/model/70060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568/file/', - }, - 'key': '4bb5c8f42315914909c764545ea44e32b04c773468c439c9eb506176670ee6b8', - 'log': 'no error, ah ah ahstill no error, suprah ah ah', - 'permissions': 'all', - 'inModel': { - 'hash': '60060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568', - 'storageAddress': 'http://chunantes.substrabac:8001/model/60060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568/file/', - }, - 'status': 'done', - 'testDataSample': { - 'keys': ['4b5152871b181d10ee774c10458c064c70710f4ba35938f10c0b7aa51f7dc010'], - 'openerHash': 'a8b7c235abb9a93742e336bd76ff7cd8ecc49f612e5cf6ea506dc10f4fd6b6f0', - 'perf': 0.12, - 'worker': 'a3119c79a173581425cbe6e06c3034ec396ee805b60d9a34feaa3048beb0e4a9', - }, - 'trainDataSample': { - 'keys': ['62fb3263208d62c7235a046ee1d80e25512fe782254b730a9e566276b8c0ef3a', - '42303efa663015e729159833a12ffb510ff92a6e386b8152f90f6fb14ddc94c9'], - 'openerHash': '615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7', - 'perf': 0.79, - 'worker': 'a3119c79a173581425cbe6e06c3034ec396ee805b60d9a34feaa3048beb0e4a9', - }, - }, - { - 'algo': { - 'hash': '56a0e2f7e046ee948cf2ab38136f7b5ff131d0c538f8d75a97850d6fc06131df', - 'name': 'Random Forest', - 'storageAddress': 'http://chunantes.substrabac:8001/algo/56a0e2f7e046ee948cf2ab38136f7b5ff131d0c538f8d75a97850d6fc06131df/file/', - }, - 'objective': { - 'hash': '3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71', - 'metrics': { - 'hash': '750f622262854341bd44f55c1018949e9c119606ef5068bd7d137040a482a756', - 'storageAddress': 'http://chunantes.substrabac:8001/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/', - }, - }, - 'creator': 'a3119c79a173581425cbe6e06c3034ec396ee805b60d9a34feaa3048beb0e4a9', - 'outModel': { - 'hash': '80060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568', - 'storageAddress': 'http://chunantes.substrabac:8001/model/80060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568/file/', - }, - 'key': 
'5bb5c8f42315914909c764545ea44e32b04c773468c439c9eb506176670ee6b8', - 'log': 'no error, ah ah ahstill no error, suprah ah ah', - 'permissions': 'all', - 'inModel': { - 'hash': '70060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568', - 'storageAddress': 'http://chunantes.substrabac:8001/model/70060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568/file/', - }, - 'status': 'done', - 'testDataSample': { - 'keys': ['4b5152871b181d10ee774c10458c064c70710f4ba35938f10c0b7aa51f7dc010'], - 'openerHash': 'a8b7c235abb9a93742e336bd76ff7cd8ecc49f612e5cf6ea506dc10f4fd6b6f0', - 'perf': 0.66, - 'worker': 'a3119c79a173581425cbe6e06c3034ec396ee805b60d9a34feaa3048beb0e4a9', - }, - 'trainDataSample': { - 'keys': ['62fb3263208d62c7235a046ee1d80e25512fe782254b730a9e566276b8c0ef3a', - '42303efa663015e729159833a12ffb510ff92a6e386b8152f90f6fb14ddc94c9'], - 'openerHash': '615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7', - 'perf': 0.79, - 'worker': 'a3119c79a173581425cbe6e06c3034ec396ee805b60d9a34feaa3048beb0e4a9', - }, - }, -] diff --git a/substrabac/substrapp/generate_exceptions_map.py b/substrabac/substrapp/generate_exceptions_map.py deleted file mode 100644 index e80438b90..000000000 --- a/substrabac/substrapp/generate_exceptions_map.py +++ /dev/null @@ -1,75 +0,0 @@ -import os -import inspect -import json - - -# Modules to inspect -os.environ['DJANGO_SETTINGS_MODULE'] = 'substrabac.settings.prod' - -import docker.errors, requests.exceptions, celery.exceptions, tarfile, \ - django.core.exceptions, django.urls, django.db, django.http, django.db.transaction,\ - rest_framework.exceptions - -MODULES = [docker.errors, requests.exceptions, celery.exceptions, tarfile, - django.core.exceptions, django.urls, django.db, django.http, django.db.transaction, - rest_framework.exceptions] - -EXCEPTION_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'exceptions.json') - - -def exception_tree(cls, exceptions_classes): - exceptions_classes.add(cls.__name__) - for subcls in cls.__subclasses__(): - exception_tree(subcls, exceptions_classes) - - -def find_exception(module): - # Exception classes in module - exceptions = [ename for ename, eclass in inspect.getmembers(module, inspect.isclass) - if issubclass(eclass, BaseException)] - - # Exception classes in submodule - for submodule in inspect.getmembers(module, inspect.ismodule): - exceptions += [ename for ename, eclass in inspect.getmembers(module, inspect.isclass) - if issubclass(eclass, BaseException)] - - return set(exceptions) - - -if __name__ == '__main__': - - exceptions_classes = set() - - # Add exceptions from modules - for errors_module in MODULES: - exceptions_classes.update(find_exception(errors_module)) - - # Add exceptions from python - exception_tree(BaseException, exceptions_classes) - - exceptions_classes = sorted(exceptions_classes) - - if os.path.exists(EXCEPTION_PATH): - # Append values to it - json_exceptions = json.load(open(EXCEPTION_PATH)) - - # get all new exceptions - exceptions_classes = [e for e in exceptions_classes if e not in json_exceptions.keys()] - - # get the last value - start_value = max(map(int, json_exceptions.values())) - - for code_exception, exception_name in enumerate(exceptions_classes, start=start_value + 1): - json_exceptions[exception_name] = f'{code_exception:04d}' - - with open(EXCEPTION_PATH, 'w') as outfile: - json.dump(json_exceptions, outfile, indent=4) - - else: - # Generate the json exceptions - json_exceptions = dict() - for code_exception, exception_name in 
enumerate(exceptions_classes, start=1): - json_exceptions[exception_name] = f'{code_exception:04d}' - - with open(EXCEPTION_PATH, 'w') as outfile: - json.dump(json_exceptions, outfile, indent=4) diff --git a/substrabac/substrapp/management/utils/localRequest.py b/substrabac/substrapp/management/utils/localRequest.py deleted file mode 100644 index 91ec423f4..000000000 --- a/substrabac/substrapp/management/utils/localRequest.py +++ /dev/null @@ -1,10 +0,0 @@ -from django.conf import settings - - -class LocalRequest(object): - - def is_secure(self): - return not getattr(settings, 'DEBUG') - - def get_host(self): - return getattr(settings, 'SITE_HOST') diff --git a/substrabac/substrapp/serializers/ledger/algo/util.py b/substrabac/substrapp/serializers/ledger/algo/util.py deleted file mode 100644 index 3dbd13981..000000000 --- a/substrabac/substrapp/serializers/ledger/algo/util.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import absolute_import, unicode_literals -from rest_framework import status - -from substrapp.models import Algo -from substrapp.utils import invokeLedger - - -def createLedgerAlgo(args, pkhash, sync=False): - - options = { - 'args': '{"Args":["registerAlgo", ' + args + ']}' - } - data, st = invokeLedger(options, sync) - - # if not created on ledger, delete from local db, else pass to validated true - try: - instance = Algo.objects.get(pk=pkhash) - except: - pass - else: - if st not in (status.HTTP_201_CREATED, status.HTTP_408_REQUEST_TIMEOUT): - instance.delete() - else: - if st != status.HTTP_408_REQUEST_TIMEOUT: - instance.validated = True - instance.save() - # update data to return - data['validated'] = True - - return data, st diff --git a/substrabac/substrapp/serializers/ledger/datamanager/util.py b/substrabac/substrapp/serializers/ledger/datamanager/util.py deleted file mode 100644 index 8469fe460..000000000 --- a/substrabac/substrapp/serializers/ledger/datamanager/util.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import absolute_import, unicode_literals -from rest_framework import status - -from substrapp.models import DataManager -from substrapp.utils import invokeLedger - - -def createLedgerDataManager(args, pkhash, sync=False): - options = { - 'args': '{"Args":["registerDataManager", ' + args + ']}' - } - data, st = invokeLedger(options, sync) - - # if not created on ledger, delete from local db, else pass to validated true - try: - instance = DataManager.objects.get(pk=pkhash) - except: - pass - else: - if st not in (status.HTTP_201_CREATED, status.HTTP_408_REQUEST_TIMEOUT): - instance.delete() - else: - if st != status.HTTP_408_REQUEST_TIMEOUT: - instance.validated = True - instance.save() - # update data to return - data['validated'] = True - - return data, st - - -def updateLedgerDataManager(args, sync=False): - options = { - 'args': '{"Args":["updateDataManager", ' + args + ']}' - } - data, st = invokeLedger(options, sync) - - return data, st diff --git a/substrabac/substrapp/serializers/ledger/datasample/util.py b/substrabac/substrapp/serializers/ledger/datasample/util.py deleted file mode 100644 index 15eacd800..000000000 --- a/substrabac/substrapp/serializers/ledger/datasample/util.py +++ /dev/null @@ -1,40 +0,0 @@ -from __future__ import absolute_import, unicode_literals -from rest_framework import status - -from substrapp.models import DataSample -from substrapp.utils import invokeLedger - - -def createLedgerDataSample(args, pkhashes, sync=False): - options = { - 'args': '{"Args":["registerDataSample", ' + args + ']}' - } - data, st = 
invokeLedger(options, sync) - - # if not created on ledger, delete from local db, else pass to validated true - try: - instances = DataSample.objects.filter(pk__in=pkhashes) - except: - pass - else: - - # delete if not created - if st not in (status.HTTP_201_CREATED, status.HTTP_408_REQUEST_TIMEOUT): - instances.delete() - else: - # do not pass to true if still waiting for validation - if st != status.HTTP_408_REQUEST_TIMEOUT: - instances.update(validated=True) - # update data to return - data['validated'] = True - - return data, st - - -def updateLedgerDataSample(args, sync=False): - options = { - 'args': '{"Args":["updateDataSample", ' + args + ']}' - } - data, st = invokeLedger(options, sync) - - return data, st diff --git a/substrabac/substrapp/serializers/ledger/objective/util.py b/substrabac/substrapp/serializers/ledger/objective/util.py deleted file mode 100644 index 8b9693bb4..000000000 --- a/substrabac/substrapp/serializers/ledger/objective/util.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import absolute_import, unicode_literals -from rest_framework import status - -from substrapp.models import Objective -from substrapp.utils import invokeLedger - - -def createLedgerObjective(args, pkhash, sync=False): - - options = { - 'args': '{"Args":["registerObjective", ' + args + ']}' - } - data, st = invokeLedger(options, sync) - - # if not created on ledger, delete from local db, else pass to validated true - try: - instance = Objective.objects.get(pk=pkhash) - except: - pass - else: - if st not in (status.HTTP_201_CREATED, status.HTTP_408_REQUEST_TIMEOUT): - instance.delete() - else: - if st != status.HTTP_408_REQUEST_TIMEOUT: - instance.validated = True - instance.save() - # update data to return - data['validated'] = True - - return data, st diff --git a/substrabac/substrapp/serializers/ledger/testtuple/util.py b/substrabac/substrapp/serializers/ledger/testtuple/util.py deleted file mode 100644 index b02ccfb0e..000000000 --- a/substrabac/substrapp/serializers/ledger/testtuple/util.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import absolute_import, unicode_literals - - -from substrapp.utils import invokeLedger - - -def createLedgerTesttuple(args, sync=False): - options = { - 'args': '{"Args":["createTesttuple", ' + args + ']}' - } - return invokeLedger(options, sync) diff --git a/substrabac/substrapp/serializers/ledger/traintuple/util.py b/substrabac/substrapp/serializers/ledger/traintuple/util.py deleted file mode 100644 index 8ca26f4b8..000000000 --- a/substrabac/substrapp/serializers/ledger/traintuple/util.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import absolute_import, unicode_literals - - -from substrapp.utils import invokeLedger - - -def createLedgerTraintuple(args, sync=False): - options = { - 'args': '{"Args":["createTraintuple", ' + args + ']}' - } - return invokeLedger(options, sync) diff --git a/substrabac/substrapp/signals/datasample/pre_save.py b/substrabac/substrapp/signals/datasample/pre_save.py deleted file mode 100644 index 92d4bee10..000000000 --- a/substrabac/substrapp/signals/datasample/pre_save.py +++ /dev/null @@ -1,60 +0,0 @@ -import logging -import shutil -from os import path, rename, link, walk, makedirs -from os.path import normpath - -from checksumdir import dirhash -from django.conf import settings -from django.core.files import File - -from substrapp.utils import uncompress_content, create_directory - - -def create_hard_links(base_dir, directory): - makedirs(directory, exist_ok=True) - for root, subdirs, files in walk(base_dir): - for 
file in files: - link(path.join(root, file), path.join(directory, file)) - for subdir in subdirs: - create_hard_links(root, subdir) - - -def data_sample_pre_save(sender, instance, **kwargs): - directory = path.join(getattr(settings, 'MEDIA_ROOT'), 'datasamples/{0}'.format(instance.pk)) - - # uncompress file if an archive - if isinstance(instance.path, File): - try: - content = instance.path.read() - instance.path.seek(0) - uncompress_content(content, directory) - except Exception as e: - logging.info(e) - raise e - else: - # calculate new hash - sha256hash = dirhash(directory, 'sha256') - # rename directory to new hash if does not exist - new_directory = path.join(getattr(settings, 'MEDIA_ROOT'), 'datasamples', sha256hash) - try: - rename(directory, new_directory) - except Exception as e: - # directory already exists with same exact data sample inside - # created by a previous save, delete directory entitled pkhash - # for avoiding duplicates - shutil.rmtree(directory) - logging.error(e, exc_info=True) - - # override defaults - instance.pkhash = sha256hash - instance.path = new_directory - # make an hardlink on all files if a path - else: - try: - p = normpath(instance.path) - create_hard_links(p, directory) - except Exception as e: - pass - else: - # override path for getting our hardlink - instance.path = directory diff --git a/substrabac/substrapp/task_utils.py b/substrabac/substrapp/task_utils.py deleted file mode 100644 index f4568cccc..000000000 --- a/substrabac/substrapp/task_utils.py +++ /dev/null @@ -1,332 +0,0 @@ -import os -import docker -import GPUtil as gputil -import threading -import time - -import logging -from django.conf import settings - -DOCKER_LABEL = 'substra_task' - - -def get_cpu_sets(cpu_count, concurrency): - cpu_step = max(1, cpu_count // concurrency) - cpu_sets = [] - - for cpu_start in range(0, cpu_count, cpu_step): - cpu_set = f'{cpu_start}-{min(cpu_start + cpu_step - 1, cpu_count - 1)}' - cpu_sets.append(cpu_set) - if len(cpu_sets) == concurrency: - break - - return cpu_sets - - -def get_gpu_sets(gpu_list, concurrency): - - if gpu_list: - gpu_count = len(gpu_list) - gpu_step = max(1, gpu_count // concurrency) - gpu_sets = [] - - for igpu_start in range(0, gpu_count, gpu_step): - gpu_sets.append(','.join(gpu_list[igpu_start: igpu_start + gpu_step])) - else: - gpu_sets = None - - return gpu_sets - - -def expand_cpu_set(cpu_set): - cpu_set_start, cpu_set_stop = map(int, cpu_set.split('-')) - return set(range(cpu_set_start, cpu_set_stop + 1)) - - -def reduce_cpu_set(expanded_cpu_set): - return f'{min(expanded_cpu_set)}-{max(expanded_cpu_set)}' - - -def expand_gpu_set(gpu_set): - return set(gpu_set.split(',')) - - -def reduce_gpu_set(expanded_gpu_set): - return ','.join(sorted(expanded_gpu_set)) - - -def filter_resources_sets(used_resources_sets, resources_sets, expand_resources_set, reduce_resources_set): - """ Filter resources_set used with resources_sets defined. 
- It will block a resources_set from resources_sets if an used_resources_set in a subset of a resources_set""" - - resources_expand = [expand_resources_set(resources_set) for resources_set in resources_sets] - used_resources_expand = [expand_resources_set(used_resources_set) for used_resources_set in used_resources_sets] - - real_used_resources_sets = [] - - for resources_set in resources_expand: - for used_resources_set in used_resources_expand: - if resources_set.intersection(used_resources_set): - real_used_resources_sets.append(reduce_resources_set(resources_set)) - break - - return list(set(resources_sets).difference(set(real_used_resources_sets))) - - -def filter_cpu_sets(used_cpu_sets, cpu_sets): - return filter_resources_sets(used_cpu_sets, cpu_sets, expand_cpu_set, reduce_cpu_set) - - -def filter_gpu_sets(used_gpu_sets, gpu_sets): - return filter_resources_sets(used_gpu_sets, gpu_sets, expand_gpu_set, reduce_gpu_set) - - -def update_statistics(task_statistics, stats, gpu_stats): - - # CPU - - if stats is not None: - - if 'cpu_stats' in stats and stats['cpu_stats']['cpu_usage'].get('total_usage', None): - # Compute CPU usage in % - delta_total_usage = (stats['cpu_stats']['cpu_usage']['total_usage'] - stats['precpu_stats']['cpu_usage']['total_usage']) - delta_system_usage = (stats['cpu_stats']['system_cpu_usage'] - stats['precpu_stats']['system_cpu_usage']) - total_usage = (delta_total_usage / delta_system_usage) * stats['cpu_stats']['online_cpus'] * 100.0 - - task_statistics['cpu']['current'].append(total_usage) - task_statistics['cpu']['max'] = max(task_statistics['cpu']['max'], - max(task_statistics['cpu']['current'])) - - # MEMORY in GB - if 'memory_stats' in stats: - current_usage = stats['memory_stats'].get('usage', None) - max_usage = stats['memory_stats'].get('max_usage', None) - - if current_usage: - task_statistics['memory']['current'].append(current_usage / 1024**3) - if max_usage: - task_statistics['memory']['max'] = max(task_statistics['memory']['max'], - max_usage / 1024**3, - max(task_statistics['memory']['current'])) - - # Network in kB - if 'networks' in stats: - task_statistics['netio']['rx'] = stats['networks']['eth0'].get('rx_bytes', 0) - task_statistics['netio']['tx'] = stats['networks']['eth0'].get('tx_bytes', 0) - - # GPU - - if gpu_stats is not None: - total_usage = sum([100 * gpu.load for gpu in gpu_stats]) - task_statistics['gpu']['current'].append(total_usage) - task_statistics['gpu']['max'] = max(task_statistics['gpu']['max'], - max(task_statistics['gpu']['current'])) - - total_usage = sum([gpu.memoryUsed for gpu in gpu_stats]) / 1024 - task_statistics['gpu_memory']['current'].append(total_usage) - task_statistics['gpu_memory']['max'] = max(task_statistics['gpu_memory']['max'], - max(task_statistics['gpu_memory']['current'])) - - -def monitoring_task(client, task_args): - """thread worker function""" - - task_name = task_args['name'] - - gpu_set = None - if 'environment' in task_args: - gpu_set = task_args['environment']['NVIDIA_VISIBLE_DEVICES'] - - start = time.time() - t = threading.currentThread() - - # Statistics - task_statistics = {'memory': {'max': 0, - 'current': [0]}, - 'gpu_memory': {'max': 0, - 'current': [0]}, - 'cpu': {'max': 0, - 'current': [0]}, - 'gpu': {'max': 0, - 'current': []}, - 'io': {'max': 0, - 'current': []}, - 'netio': {'rx': 0, - 'tx': 0}, - 'time': 0} - - while not t.stopthread.isSet(): - stats = None - try: - container = client.containers.get(task_name) - stats = container.stats(decode=True, stream=False) - except 
(docker.errors.NotFound, docker.errors.APIError): - pass - - gpu_stats = None - if gpu_set is not None: - gpu_stats = [gpu for gpu in gputil.getGPUs() if str(gpu.id) in gpu_set] - - update_statistics(task_statistics, stats, gpu_stats) - - task_statistics['time'] = time.time() - start - - t._stats = task_statistics - - t._result = f"CPU:{t._stats['cpu']['max']:.2f} % - Mem:{t._stats['memory']['max']:.2f}" - t._result += f" GB - GPU:{t._stats['gpu']['max']:.2f} % - GPU Mem:{t._stats['gpu_memory']['max']:.2f} GB" - - -def compute_docker(client, resources_manager, dockerfile_path, image_name, container_name, volumes, command, remove_image=True): - - dockerfile_fullpath = os.path.join(dockerfile_path, 'Dockerfile') - if not os.path.exists(dockerfile_fullpath): - raise Exception(f'Dockerfile does not exist : {dockerfile_fullpath}') - - # Build metrics - client.images.build(path=dockerfile_path, - tag=image_name, - rm=remove_image) - - # Limit ressources - memory_limit_mb = f'{resources_manager.memory_limit_mb()}M' - cpu_set, gpu_set = resources_manager.get_cpu_gpu_sets() # blocking call - - task_args = {'image': image_name, - 'name': container_name, - 'cpuset_cpus': cpu_set, - 'mem_limit': memory_limit_mb, - 'command': command, - 'volumes': volumes, - 'shm_size': '8G', - 'labels': [DOCKER_LABEL], - 'detach': False, - 'auto_remove': False, - 'remove': False} - - if gpu_set is not None: - task_args['environment'] = {'NVIDIA_VISIBLE_DEVICES': gpu_set} - task_args['runtime'] = 'nvidia' - - task = ExceptionThread(target=client.containers.run, - kwargs=task_args) - - monitoring = ExceptionThread(target=monitoring_task, - args=(client, task_args)) - - task.start() - monitoring.start() - - task.join() - monitoring.join() - - # Remove container in all case (exception thrown or not) - # We already get the excetion and we need to remove the containers to be able to remove the local - # volume in case of fl task. - container = client.containers.get(container_name) - container.remove() - - # Remove images - if remove_image or hasattr(task, "_exception"): - client.images.remove(image_name, force=True) - - if hasattr(task, "_exception"): - raise task._exception - - return monitoring._result - - -class ExceptionThread(threading.Thread): - - def __init__(self, *args, **kwargs): - super(ExceptionThread, self).__init__(*args, **kwargs) - self.stopthread = threading.Event() - - def run(self): - try: - if self._target: - self._target(*self._args, **self._kwargs) - except BaseException as e: - self._exception = e - raise e - finally: - # Avoid a refcycle if the thread is running a function with - # an argument that has a member that points to the thread. - del self._target, self._args, self._kwargs - - def join(self, timeout=None): - self.stopthread.set() - super(ExceptionThread, self).join(timeout) - - -class ResourcesManager(): - - __concurrency = int(getattr(settings, 'CELERY_WORKER_CONCURRENCY')) - __memory_mb = int(os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES') / (1024. 
** 2)) - __memory_mb_per_task = __memory_mb // __concurrency - - __cpu_count = os.cpu_count() - __cpu_sets = get_cpu_sets(__cpu_count, __concurrency) - - # Set CUDA_DEVICE_ORDER so the IDs assigned by CUDA match those from nvidia-smi - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - __gpu_list = [str(gpu.id) for gpu in gputil.getGPUs()] - __gpu_sets = get_gpu_sets(__gpu_list, __concurrency) # Can be None if no gpu - - __lock = threading.Lock() - __docker = docker.from_env() - - @classmethod - def memory_limit_mb(cls): - return cls.__memory_mb_per_task - - @classmethod - def get_cpu_gpu_sets(cls): - - cpu_set = None - gpu_set = None - - with cls.__lock: - # We can just wait for cpu because cpu and gpu is allocated the same way - while cpu_set is None: - - # Get ressources used - filters = {'status': 'running', - 'label': [DOCKER_LABEL]} - - try: - containers = [container.attrs - for container in cls.__docker.containers.list(filters=filters)] - except docker.errors.APIError as e: - logging.error(e, exc_info=True) - continue - - # CPU - used_cpu_sets = [container['HostConfig']['CpusetCpus'] - for container in containers - if container['HostConfig']['CpusetCpus']] - - cpu_sets_available = filter_cpu_sets(used_cpu_sets, cls.__cpu_sets) - - if cpu_sets_available: - cpu_set = cpu_sets_available.pop() - - # GPU - if cls.__gpu_sets is not None: - env_containers = [container['Config']['Env'] - for container in containers] - - used_gpu_sets = [] - - for env_list in env_containers: - nvidia_env_var = [s.split('=')[1] - for s in env_list if "NVIDIA_VISIBLE_DEVICES" in s] - - used_gpu_sets.extend(nvidia_env_var) - - gpu_sets_available = filter_gpu_sets(used_gpu_sets, cls.__gpu_sets) - - if gpu_sets_available: - gpu_set = gpu_sets_available.pop() - - return cpu_set, gpu_set diff --git a/substrabac/substrapp/tasks.py b/substrabac/substrapp/tasks.py deleted file mode 100644 index da89344b3..000000000 --- a/substrabac/substrapp/tasks.py +++ /dev/null @@ -1,510 +0,0 @@ -from __future__ import absolute_import, unicode_literals - -import os -import shutil -import tempfile -from os import path - -from checksumdir import dirhash -from django.conf import settings -from rest_framework import status -from rest_framework.reverse import reverse - -from substrabac.celery import app -from substrapp.utils import queryLedger, invokeLedger, get_hash, create_directory, get_remote_file, uncompress_content -from substrapp.task_utils import ResourcesManager, compute_docker -from substrapp.exception_handler import compute_error_code - -import docker -import json -from multiprocessing.managers import BaseManager - -import logging - - -def get_objective(subtuple): - from substrapp.models import Objective - - # check if objective exists and its metrics is not null - objectiveHash = subtuple['objective']['hash'] - - try: - # get objective from local db - objective = Objective.objects.get(pk=objectiveHash) - except: - objective = None - finally: - if objective is None or not objective.metrics: - # get objective metrics - try: - content, computed_hash = get_remote_file(subtuple['objective']['metrics']) - except Exception as e: - raise e - - objective, created = Objective.objects.update_or_create(pkhash=objectiveHash, validated=True) - - try: - f = tempfile.TemporaryFile() - f.write(content) - objective.metrics.save('metrics.py', f) # update objective in local db for later use - except Exception as e: - logging.error('Failed to save objective metrics in local db for later use') - raise e - - return objective - - -def 
get_algo(subtuple): - algo_content, algo_computed_hash = get_remote_file(subtuple['algo']) - return algo_content, algo_computed_hash - - -def get_model(subtuple): - model_content, model_computed_hash = None, None - - if subtuple.get('model', None) is not None: - model_content, model_computed_hash = get_remote_file(subtuple['model'], subtuple['model']['traintupleKey']) - - return model_content, model_computed_hash - - -def get_models(subtuple): - models_content, models_computed_hash = [], [] - - if subtuple.get('inModels', None) is not None: - for subtuple_model in subtuple['inModels']: - model_content, model_computed_hash = get_remote_file(subtuple_model, subtuple_model['traintupleKey']) - models_content.append(model_content) - models_computed_hash.append(model_computed_hash) - - return models_content, models_computed_hash - - -def put_model(subtuple, subtuple_directory, model_content): - if model_content is not None: - from substrapp.models import Model - - model_dst_path = path.join(subtuple_directory, f'model/{subtuple["model"]["traintupleKey"]}') - - try: - model = Model.objects.get(pk=subtuple['model']['hash']) - except: # write it to local disk - with open(model_dst_path, 'wb') as f: - f.write(model_content) - else: - if get_hash(model.file.path, subtuple["model"]["traintupleKey"]) != subtuple['model']['hash']: - raise Exception('Model Hash in Subtuple is not the same as in local db') - - if not os.path.exists(model_dst_path): - os.link(model.file.path, model_dst_path) - else: - if get_hash(model_dst_path, subtuple["model"]["traintupleKey"]) != subtuple['model']['hash']: - raise Exception('Model Hash in Subtuple is not the same as in local medias') - - -def put_models(subtuple, subtuple_directory, models_content): - if models_content: - from substrapp.models import Model - - for model_content, subtuple_model in zip(models_content, subtuple['inModels']): - model_dst_path = path.join(subtuple_directory, f'model/{subtuple_model["traintupleKey"]}') - - try: - model = Model.objects.get(pk=subtuple_model['hash']) - except: # write it to local disk - with open(model_dst_path, 'wb') as f: - f.write(model_content) - else: - if get_hash(model.file.path, subtuple_model["traintupleKey"]) != subtuple_model['hash']: - raise Exception('Model Hash in Subtuple is not the same as in local db') - - if not os.path.exists(model_dst_path): - os.link(model.file.path, model_dst_path) - else: - if get_hash(model_dst_path, subtuple_model["traintupleKey"]) != subtuple_model['hash']: - raise Exception('Model Hash in Subtuple is not the same as in local medias') - - -def put_opener(subtuple, subtuple_directory): - from substrapp.models import DataManager - - try: - datamanager = DataManager.objects.get(pk=subtuple['dataset']['openerHash']) - except Exception as e: - raise e - - data_opener_hash = get_hash(datamanager.data_opener.path) - if data_opener_hash != subtuple['dataset']['openerHash']: - raise Exception('DataOpener Hash in Subtuple is not the same as in local db') - - opener_dst_path = path.join(subtuple_directory, 'opener/opener.py') - if not os.path.exists(opener_dst_path): - os.link(datamanager.data_opener.path, opener_dst_path) - - -def put_data_sample(subtuple, subtuple_directory): - from substrapp.models import DataSample - - for data_sample_key in subtuple['dataset']['keys']: - try: - data_sample = DataSample.objects.get(pk=data_sample_key) - except Exception as e: - raise e - else: - data_sample_hash = dirhash(data_sample.path, 'sha256') - if data_sample_hash != data_sample_key: - raise 
Exception('Data Sample Hash in Subtuple is not the same as in local db') - - # create a symlink on the folder containing data - try: - subtuple_data_directory = path.join(subtuple_directory, 'data', data_sample_key) - os.symlink(data_sample.path, subtuple_data_directory) - except Exception as e: - logging.error(e, exc_info=True) - raise Exception('Failed to create sym link for subtuple data sample') - - -def put_metric(subtuple_directory, objective): - metrics_dst_path = path.join(subtuple_directory, 'metrics/metrics.py') - if not os.path.exists(metrics_dst_path): - os.link(objective.metrics.path, metrics_dst_path) - - -def put_algo(subtuple_directory, algo_content): - try: - uncompress_content(algo_content, subtuple_directory) - except Exception as e: - logging.error('Fail to uncompress algo file') - raise e - - -def build_subtuple_folders(subtuple): - # create a folder named subtuple['key'] im /medias/subtuple with 5 folders opener, data, model, pred, metrics - subtuple_directory = path.join(getattr(settings, 'MEDIA_ROOT'), 'subtuple', subtuple['key']) - create_directory(subtuple_directory) - for folder in ['opener', 'data', 'model', 'pred', 'metrics']: - create_directory(path.join(subtuple_directory, folder)) - - return subtuple_directory - - -def remove_subtuple_materials(subtuple_directory): - try: - shutil.rmtree(subtuple_directory) - except Exception as e: - logging.error(e) - - -def fail(key, err_msg, tuple_type): - # Log Fail TrainTest - err_msg = str(err_msg).replace('"', "'").replace('\\', "").replace('\\n', "")[:200] - fail_type = 'logFailTrain' if tuple_type == 'traintuple' else 'logFailTest' - data, st = invokeLedger({ - 'args': f'{{"Args":["{fail_type}","{key}","{err_msg}"]}}' - }, sync=True) - - if st is not status.HTTP_201_CREATED: - logging.error(data, exc_info=True) - - logging.info('Successfully passed the subtuple to failed') - return data - - -# Instatiate Ressource Manager in BaseManager to share it between celery concurrent tasks -BaseManager.register('ResourcesManager', ResourcesManager) -manager = BaseManager() -manager.start() -resources_manager = manager.ResourcesManager() - - -def prepareTask(tuple_type, model_type): - from django_celery_results.models import TaskResult - - try: - data_owner = get_hash(settings.LEDGER['signcert']) - except Exception as e: - logging.error(e, exc_info=True) - else: - subtuples, st = queryLedger({ - 'args': f'{{"Args":["queryFilter","{tuple_type}~worker~status","{data_owner},todo"]}}' - }) - - if st == status.HTTP_200_OK and subtuples is not None: - for subtuple in subtuples: - - fltask = None - worker_queue = f"{settings.LEDGER['name']}.worker" - - if 'fltask' in subtuple and subtuple['fltask']: - fltask = subtuple['fltask'] - flresults = TaskResult.objects.filter( - task_name='substrapp.tasks.computeTask', - result__icontains=f'"fltask": "{fltask}"') - - if flresults and flresults.count() > 0: - worker_queue = json.loads(flresults.first().as_dict()['result'])['worker'] - - try: - # Log Start of the Subtuple - start_type = 'logStartTrain' if tuple_type == 'traintuple' else 'logStartTest' if tuple_type == 'testtuple' else None - data, st = invokeLedger({ - 'args': f'{{"Args":["{start_type}","{subtuple["key"]}"]}}' - }, sync=True) - - if st not in (status.HTTP_201_CREATED, status.HTTP_408_REQUEST_TIMEOUT): - logging.error( - f'Failed to invoke ledger on prepareTask {tuple_type}. 
Error: {data}') - else: - computeTask.apply_async( - (tuple_type, subtuple, model_type, fltask), - queue=worker_queue) - - except Exception as e: - error_code = compute_error_code(e) - logging.error(error_code, exc_info=True) - return fail(subtuple['key'], error_code, tuple_type) - - -@app.task(bind=True, ignore_result=True) -def prepareTrainingTask(self): - prepareTask('traintuple', 'inModels') - - -@app.task(ignore_result=True) -def prepareTestingTask(): - prepareTask('testtuple', 'model') - - -@app.task(bind=True, ignore_result=False) -def computeTask(self, tuple_type, subtuple, model_type, fltask): - - try: - worker = self.request.hostname.split('@')[1] - queue = self.request.delivery_info['routing_key'] - except: - worker = f"{settings.LEDGER['name']}.worker" - queue = f"{settings.LEDGER['name']}" - - result = {'worker': worker, 'queue': queue, 'fltask': fltask} - - # Get materials - try: - prepareMaterials(subtuple, model_type) - except Exception as e: - error_code = compute_error_code(e) - logging.error(error_code, exc_info=True) - fail(subtuple['key'], error_code, tuple_type) - return result - - logging.info(f'Prepare Task success {tuple_type}') - - try: - res = doTask(subtuple, tuple_type) - except Exception as e: - error_code = compute_error_code(e) - logging.error(error_code, exc_info=True) - fail(subtuple['key'], error_code, tuple_type) - return result - else: - # Invoke ledger to log success - if tuple_type == 'traintuple': - invoke_args = f'{{"Args":["logSuccessTrain","{subtuple["key"]}", "{res["end_model_file_hash"]}, {res["end_model_file"]}","{res["global_perf"]}","Train - {res["job_task_log"]}; "]}}' - elif tuple_type == 'testtuple': - invoke_args = f'{{"Args":["logSuccessTest","{subtuple["key"]}","{res["global_perf"]}","Test - {res["job_task_log"]}; "]}}' - - data, st = invokeLedger({ - 'args': invoke_args - }, sync=True) - - if st not in (status.HTTP_201_CREATED, status.HTTP_408_REQUEST_TIMEOUT): - logging.error('Failed to invoke ledger on logSuccess') - logging.error(data) - - return result - - -def prepareMaterials(subtuple, model_type): - # get subtuple components - try: - objective = get_objective(subtuple) - algo_content, algo_computed_hash = get_algo(subtuple) - if model_type == 'model': - model_content, model_computed_hash = get_model(subtuple) # can return None, None - if model_type == 'inModels': - models_content, models_computed_hash = get_models(subtuple) # can return [], [] - - except Exception as e: - raise e - - # create subtuple - try: - subtuple_directory = build_subtuple_folders(subtuple) # do not put anything in pred folder - put_opener(subtuple, subtuple_directory) - put_data_sample(subtuple, subtuple_directory) - put_metric(subtuple_directory, objective) - put_algo(subtuple_directory, algo_content) - if model_type == 'model': # testtuple - put_model(subtuple, subtuple_directory, model_content) - if model_type == 'inModels': # traintuple - put_models(subtuple, subtuple_directory, models_content) - - except Exception as e: - raise e - - -def doTask(subtuple, tuple_type): - subtuple_directory = path.join(getattr(settings, 'MEDIA_ROOT'), 'subtuple', subtuple['key']) - org_name = getattr(settings, 'ORG_NAME') - - # Federated learning variables - fltask = None - flrank = None - - if 'fltask' in subtuple and subtuple['fltask']: - fltask = subtuple['fltask'] - flrank = int(subtuple['rank']) - - # Computation - try: - # Job log - job_task_log = '' - - # Setup Docker Client - client = docker.from_env() - - # subtuple setup - model_path = 
path.join(subtuple_directory, 'model') - data_path = path.join(subtuple_directory, 'data') - - ########################################## - # RESOLVE SYMLINKS - # TO DO: - # - Verify that real paths are safe - # - Try to see if it's clean to do that - ########################################## - symlinks_volume = {} - for subfolder in os.listdir(data_path): - real_path = os.path.realpath(os.path.join(data_path, subfolder)) - symlinks_volume[real_path] = {'bind': f'{real_path}', 'mode': 'ro'} - - ########################################## - - pred_path = path.join(subtuple_directory, 'pred') - opener_file = path.join(subtuple_directory, 'opener/opener.py') - metrics_file = path.join(subtuple_directory, 'metrics/metrics.py') - volumes = {data_path: {'bind': '/sandbox/data', 'mode': 'ro'}, - pred_path: {'bind': '/sandbox/pred', 'mode': 'rw'}, - metrics_file: {'bind': '/sandbox/metrics/__init__.py', 'mode': 'ro'}, - opener_file: {'bind': '/sandbox/opener/__init__.py', 'mode': 'ro'}} - - # compute algo task - algo_path = path.join(subtuple_directory) - algo_docker = f'algo_{tuple_type}'.lower() # tag must be lowercase for docker - algo_docker_name = f'{algo_docker}_{subtuple["key"]}' - model_volume = {model_path: {'bind': '/sandbox/model', 'mode': 'rw'}} - - if fltask is not None and flrank != -1: - remove_image = False - else: - remove_image = True - - # create the command option for algo - if tuple_type == 'traintuple': - algo_command = 'train' # main command - - # add list of inmodels - if subtuple['inModels'] is not None: - inmodels = [subtuple_model["traintupleKey"] for subtuple_model in subtuple['inModels']] - algo_command = f"{algo_command} {' '.join(inmodels)}" - - # add fltask rank for training - if flrank is not None: - algo_command = f"{algo_command} --rank {flrank}" - - elif tuple_type == 'testtuple': - algo_command = 'predict' # main command - - inmodels = subtuple['model']["traintupleKey"] - algo_command = f'{algo_command} {inmodels}' - - # local volume for fltask - if fltask is not None and tuple_type == 'traintuple': - flvolume = f'local-{fltask}-{org_name}' - if flrank == 0: - client.volumes.create(name=flvolume) - else: - client.volumes.get(volume_id=flvolume) - - model_volume[flvolume] = {'bind': '/sandbox/local', 'mode': 'rw'} - - job_task_log = compute_docker(client=client, - resources_manager=resources_manager, - dockerfile_path=algo_path, - image_name=algo_docker, - container_name=algo_docker_name, - volumes={**volumes, **model_volume, **symlinks_volume}, - command=algo_command, - remove_image=remove_image) - # save model in database - if tuple_type == 'traintuple': - from substrapp.models import Model - end_model_path = path.join(subtuple_directory, 'model/model') - end_model_file_hash = get_hash(end_model_path, subtuple['key']) - try: - instance = Model.objects.create(pkhash=end_model_file_hash, validated=True) - except Exception as e: - error_code = compute_error_code(e) - logging.error(error_code, exc_info=True) - return fail(subtuple['key'], error_code, tuple_type) - - with open(end_model_path, 'rb') as f: - instance.file.save('model', f) - current_site = getattr(settings, "DEFAULT_DOMAIN") - end_model_file = f'{current_site}{reverse("substrapp:model-file", args=[end_model_file_hash])}' - - # compute metric task - metrics_path = path.join(getattr(settings, 'PROJECT_ROOT'), 'base_metrics') # base metrics comes with substrabac - metrics_docker = f'metrics_{tuple_type}'.lower() # tag must be lowercase for docker - metrics_docker_name = 
f'{metrics_docker}_{subtuple["key"]}' - compute_docker(client=client, - resources_manager=resources_manager, - dockerfile_path=metrics_path, - image_name=metrics_docker, - container_name=metrics_docker_name, - volumes={**volumes, **symlinks_volume}, - command=None, - remove_image=remove_image) - - # load performance - with open(path.join(pred_path, 'perf.json'), 'r') as perf_file: - perf = json.load(perf_file) - global_perf = perf['all'] - - except Exception as e: - # If an exception is thrown set flrank == -1 (we stop the fl training) - if fltask is not None: - flrank = -1 - - raise e - else: - result = {'global_perf': global_perf, - 'job_task_log': job_task_log} - - if tuple_type == 'traintuple': - result['end_model_file_hash'] = end_model_file_hash - result['end_model_file'] = end_model_file - - finally: - # Clean subtuple materials - remove_subtuple_materials(subtuple_directory) - - # Rank == -1 -> Last fl subtuple or fl throws an exception - if flrank == -1: - flvolume = f'local-{fltask}-{org_name}' - local_volume = client.volumes.get(volume_id=flvolume) - try: - local_volume.remove(force=True) - except: - logging.error(f'Cannot remove local volume {flvolume}', exc_info=True) - - return result diff --git a/substrabac/substrapp/tests/assets.py b/substrabac/substrapp/tests/assets.py deleted file mode 100644 index 24c6fd30e..000000000 --- a/substrabac/substrapp/tests/assets.py +++ /dev/null @@ -1,340 +0,0 @@ -objective = [ - { - "key": "1cdafbb018dd195690111d74916b76c96892d897ec3587c814f287946db446c3", - "name": "Skin Lesion Classification Objective", - "description": { - "hash": "1cdafbb018dd195690111d74916b76c96892d897ec3587c814f287946db446c3", - "storageAddress": "http://testserver/objective/1cdafbb018dd195690111d74916b76c96892d897ec3587c814f287946db446c3/description/" - }, - "metrics": { - "name": "macro-average recall", - "hash": "c42dca31fbc2ebb5705643e3bb6ee666bbfd956de13dd03727f825ad8445b4d7", - "storageAddress": "http://testserver/objective/1cdafbb018dd195690111d74916b76c96892d897ec3587c814f287946db446c3/metrics/" - }, - "owner": "fba9c2538319fe2b45ac7047e21b4bc7196537367814d5da7f0aae020d3be5f7", - "testDataset": None, - "permissions": "all" - }, - { - "key": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", - "name": "Skin Lesion Classification Objective", - "description": { - "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", - "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/description/" - }, - "metrics": { - "name": "macro-average recall", - "hash": "c42dca31fbc2ebb5705643e3bb6ee666bbfd956de13dd03727f825ad8445b4d7", - "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" - }, - "owner": "fba9c2538319fe2b45ac7047e21b4bc7196537367814d5da7f0aae020d3be5f7", - "testDataset": { - "dataManagerKey": "615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7", - "dataSampleKeys": [ - "8bf3bf4f753a32f27d18c86405e7a406a83a55610d91abcca9acc525061b8ecf", - "17d58b67ae2028018108c9bf555fa58b2ddcfe560e0117294196e79d26140b2a" - ] - }, - "permissions": "all" - } -] - -datamanager = [ - { - "objectiveKey": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", - "description": { - "hash": "15863c2af1fcfee9ca6f61f04be8a0eaaf6a45e4d50c421788d450d198e580f1", - "storageAddress": "http://testserver/data_manager/615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7/description/" - }, - "key": 
"615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7", - "name": "ISIC 2018", - "opener": { - "hash": "615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7", - "storageAddress": "http://testserver/data_manager/615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7/opener/" - }, - "owner": "2cb13d299b337fae2969da1ff4ddd9a2f3004be52d64f671d13d9513f5a79426", - "permissions": "all", - "type": "Images" - }, - { - "objectiveKey": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", - "description": { - "hash": "258bef187a166b3fef5cb86e68c8f7e154c283a148cd5bc344fec7e698821ad3", - "storageAddress": "http://testserver/data_manager/615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7/description/" - }, - "key": "615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7", - "name": "Simplified ISIC 2018", - "opener": { - "hash": "615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7", - "storageAddress": "http://testserver/data_manager/615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7/opener/" - }, - "owner": "fba9c2538319fe2b45ac7047e21b4bc7196537367814d5da7f0aae020d3be5f7", - "permissions": "all", - "type": "Images" - } -] - -algo = [ - { - "key": "0acc5180e09b6a6ac250f4e3c172e2893f617aa1c22ef1f379019d20fe44142f", - "name": "Neural Network", - "content": { - "hash": "0acc5180e09b6a6ac250f4e3c172e2893f617aa1c22ef1f379019d20fe44142f", - "storageAddress": "http://testserver/algo/0acc5180e09b6a6ac250f4e3c172e2893f617aa1c22ef1f379019d20fe44142f/file/" - }, - "description": { - "hash": "b9463411a01ea00869bdffce6e59a5c100a4e635c0a9386266cad3c77eb28e9e", - "storageAddress": "http://testserver/algo/0acc5180e09b6a6ac250f4e3c172e2893f617aa1c22ef1f379019d20fe44142f/description/" - }, - "owner": "2cb13d299b337fae2969da1ff4ddd9a2f3004be52d64f671d13d9513f5a79426", - "permissions": "all" - }, - { - "key": "4cc53726e01f7e3864a6cf9da24d9cef04a7cbd7fd2892765ff76931dd4628e7", - "name": "Logistic regression", - "content": { - "hash": "4cc53726e01f7e3864a6cf9da24d9cef04a7cbd7fd2892765ff76931dd4628e7", - "storageAddress": "http://testserver/algo/4cc53726e01f7e3864a6cf9da24d9cef04a7cbd7fd2892765ff76931dd4628e7/file/" - }, - "description": { - "hash": "124a0425b746d7072282d167b53cb6aab3a31bf1946dae89135c15b0126ebec3", - "storageAddress": "http://testserver/algo/4cc53726e01f7e3864a6cf9da24d9cef04a7cbd7fd2892765ff76931dd4628e7/description/" - }, - "owner": "2cb13d299b337fae2969da1ff4ddd9a2f3004be52d64f671d13d9513f5a79426", - "permissions": "all" - }, - { - "key": "9c3d8777e11fd72cbc0fd672bec3a0848f8518b4d56706008cc05f8a1cee44f9", - "name": "Random Forest", - "content": { - "hash": "9c3d8777e11fd72cbc0fd672bec3a0848f8518b4d56706008cc05f8a1cee44f9", - "storageAddress": "http://testserver/algo/9c3d8777e11fd72cbc0fd672bec3a0848f8518b4d56706008cc05f8a1cee44f9/file/" - }, - "description": { - "hash": "4acea40c4b51996c88ef279c5c9aa41ab77b97d38c5ca167e978a98b2e402675", - "storageAddress": "http://testserver/algo/9c3d8777e11fd72cbc0fd672bec3a0848f8518b4d56706008cc05f8a1cee44f9/description/" - }, - "owner": "2cb13d299b337fae2969da1ff4ddd9a2f3004be52d64f671d13d9513f5a79426", - "permissions": "all" - } -] - -traintuple = [ - { - "algo": { - "hash": "0acc5180e09b6a6ac250f4e3c172e2893f617aa1c22ef1f379019d20fe44142f", - "name": "Neural Network", - "storageAddress": "http://testserver/algo/0acc5180e09b6a6ac250f4e3c172e2893f617aa1c22ef1f379019d20fe44142f/file/" - }, - "creator": "2cb13d299b337fae2969da1ff4ddd9a2f3004be52d64f671d13d9513f5a79426", - 
"dataset": { - "keys": [ - "31510dc1d8be788f7c5d28d05714f7efb9edb667762966b9adc02eadeaacebe9", - "03a1f878768ea8624942d46a3b438c37992e626c2cf655023bcc3bed69d485d1" - ], - "openerHash": "615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7", - "perf": 0, - "worker": "2cb13d299b337fae2969da1ff4ddd9a2f3004be52d64f671d13d9513f5a79426" - }, - "fltask": "", - "inModels": None, - "key": "c4e3116dd3f895986b77e4d445178330630bd3f52407f10462dd4778e40090e0", - "log": "[00-01-0032-7cc5b61]", - "objective": { - "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", - "metrics": { - "hash": "c42dca31fbc2ebb5705643e3bb6ee666bbfd956de13dd03727f825ad8445b4d7", - "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" - } - }, - "outModel": None, - "permissions": "all", - "rank": 0, - "status": "failed", - "tag": "My super tag" - }, - { - "algo": { - "hash": "4cc53726e01f7e3864a6cf9da24d9cef04a7cbd7fd2892765ff76931dd4628e7", - "name": "Logistic regression", - "storageAddress": "http://testserver/algo/4cc53726e01f7e3864a6cf9da24d9cef04a7cbd7fd2892765ff76931dd4628e7/file/" - }, - "creator": "2cb13d299b337fae2969da1ff4ddd9a2f3004be52d64f671d13d9513f5a79426", - "dataset": { - "keys": [ - "31510dc1d8be788f7c5d28d05714f7efb9edb667762966b9adc02eadeaacebe9", - "03a1f878768ea8624942d46a3b438c37992e626c2cf655023bcc3bed69d485d1" - ], - "openerHash": "615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7", - "perf": 1, - "worker": "2cb13d299b337fae2969da1ff4ddd9a2f3004be52d64f671d13d9513f5a79426" - }, - "fltask": "", - "inModels": None, - "key": "3979576752e014adddadfc360d79c67cdccb0f4bae46936f35ce09c64e5832c8", - "log": "Train - CPU:173.81 % - Mem:0.11 GB - GPU:0.00 % - GPU Mem:0.00 GB; ", - "objective": { - "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", - "metrics": { - "hash": "c42dca31fbc2ebb5705643e3bb6ee666bbfd956de13dd03727f825ad8445b4d7", - "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" - } - }, - "outModel": { - "hash": "592242f9b162178994897c5b8aa49450a17cc395bb9bc9864b830a6cdba6a075", - "storageAddress": "http://testserver/model/592242f9b162178994897c5b8aa49450a17cc395bb9bc9864b830a6cdba6a075/file/" - }, - "permissions": "all", - "rank": 0, - "status": "done", - "tag": "substra" - }, - { - "algo": { - "hash": "9c3d8777e11fd72cbc0fd672bec3a0848f8518b4d56706008cc05f8a1cee44f9", - "name": "Random Forest", - "storageAddress": "http://testserver/algo/9c3d8777e11fd72cbc0fd672bec3a0848f8518b4d56706008cc05f8a1cee44f9/file/" - }, - "creator": "2cb13d299b337fae2969da1ff4ddd9a2f3004be52d64f671d13d9513f5a79426", - "dataset": { - "keys": [ - "31510dc1d8be788f7c5d28d05714f7efb9edb667762966b9adc02eadeaacebe9", - "03a1f878768ea8624942d46a3b438c37992e626c2cf655023bcc3bed69d485d1" - ], - "openerHash": "615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7", - "perf": 0, - "worker": "2cb13d299b337fae2969da1ff4ddd9a2f3004be52d64f671d13d9513f5a79426" - }, - "fltask": "", - "inModels": None, - "key": "c6beed3a4ee5ead0c4246faac7931a944fc2286e193454bb1b851dee0c5a5f59", - "log": "[00-01-0032-139c39e]", - "objective": { - "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", - "metrics": { - "hash": "c42dca31fbc2ebb5705643e3bb6ee666bbfd956de13dd03727f825ad8445b4d7", - "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" - } 
- }, - "outModel": None, - "permissions": "all", - "rank": 0, - "status": "failed", - "tag": "" - } -] - -testtuple = [ - { - "key": "b7b9291e5ff96ec7d16d38ab49915cbe15055347bb933a824887f2a76fb57c9a", - "algo": { - "name": "Logistic regression", - "hash": "4cc53726e01f7e3864a6cf9da24d9cef04a7cbd7fd2892765ff76931dd4628e7", - "storageAddress": "http://testserver/algo/4cc53726e01f7e3864a6cf9da24d9cef04a7cbd7fd2892765ff76931dd4628e7/file/" - }, - "certified": True, - "creator": "2cb13d299b337fae2969da1ff4ddd9a2f3004be52d64f671d13d9513f5a79426", - "dataset": { - "worker": "fba9c2538319fe2b45ac7047e21b4bc7196537367814d5da7f0aae020d3be5f7", - "keys": [ - "17d58b67ae2028018108c9bf555fa58b2ddcfe560e0117294196e79d26140b2a", - "8bf3bf4f753a32f27d18c86405e7a406a83a55610d91abcca9acc525061b8ecf" - ], - "openerHash": "615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7", - "perf": 0 - }, - "log": "Test - CPU:179.46 % - Mem:0.09 GB - GPU:0.00 % - GPU Mem:0.00 GB; ", - "model": { - "traintupleKey": "3979576752e014adddadfc360d79c67cdccb0f4bae46936f35ce09c64e5832c8", - "hash": "592242f9b162178994897c5b8aa49450a17cc395bb9bc9864b830a6cdba6a075", - "storageAddress": "http://testserver/model/592242f9b162178994897c5b8aa49450a17cc395bb9bc9864b830a6cdba6a075/file/" - }, - "objective": { - "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", - "metrics": { - "hash": "c42dca31fbc2ebb5705643e3bb6ee666bbfd956de13dd03727f825ad8445b4d7", - "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" - } - }, - "permissions": "all", - "status": "done", - "tag": "" - } -] - -model = [ - { - "testtuple": { - "algo": { - "hash": "4cc53726e01f7e3864a6cf9da24d9cef04a7cbd7fd2892765ff76931dd4628e7", - "name": "Logistic regression", - "storageAddress": "http://testserver/algo/4cc53726e01f7e3864a6cf9da24d9cef04a7cbd7fd2892765ff76931dd4628e7/file/" - }, - "certified": True, - "creator": "2cb13d299b337fae2969da1ff4ddd9a2f3004be52d64f671d13d9513f5a79426", - "dataset": { - "keys": [ - "17d58b67ae2028018108c9bf555fa58b2ddcfe560e0117294196e79d26140b2a", - "8bf3bf4f753a32f27d18c86405e7a406a83a55610d91abcca9acc525061b8ecf" - ], - "openerHash": "615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7", - "perf": 0, - "worker": "fba9c2538319fe2b45ac7047e21b4bc7196537367814d5da7f0aae020d3be5f7" - }, - "key": "b7b9291e5ff96ec7d16d38ab49915cbe15055347bb933a824887f2a76fb57c9a", - "log": "Test - CPU:179.46 % - Mem:0.09 GB - GPU:0.00 % - GPU Mem:0.00 GB; ", - "model": { - "hash": "592242f9b162178994897c5b8aa49450a17cc395bb9bc9864b830a6cdba6a075", - "storageAddress": "http://testserver/model/592242f9b162178994897c5b8aa49450a17cc395bb9bc9864b830a6cdba6a075/file/", - "traintupleKey": "3979576752e014adddadfc360d79c67cdccb0f4bae46936f35ce09c64e5832c8" - }, - "objective": { - "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", - "metrics": { - "hash": "c42dca31fbc2ebb5705643e3bb6ee666bbfd956de13dd03727f825ad8445b4d7", - "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" - } - }, - "permissions": "all", - "status": "done", - "tag": "" - }, - "traintuple": { - "algo": { - "hash": "4cc53726e01f7e3864a6cf9da24d9cef04a7cbd7fd2892765ff76931dd4628e7", - "name": "Logistic regression", - "storageAddress": "http://testserver/algo/4cc53726e01f7e3864a6cf9da24d9cef04a7cbd7fd2892765ff76931dd4628e7/file/" - }, - "creator": 
"2cb13d299b337fae2969da1ff4ddd9a2f3004be52d64f671d13d9513f5a79426", - "dataset": { - "keys": [ - "31510dc1d8be788f7c5d28d05714f7efb9edb667762966b9adc02eadeaacebe9", - "03a1f878768ea8624942d46a3b438c37992e626c2cf655023bcc3bed69d485d1" - ], - "openerHash": "615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7", - "perf": 1, - "worker": "2cb13d299b337fae2969da1ff4ddd9a2f3004be52d64f671d13d9513f5a79426" - }, - "fltask": "", - "inModels": None, - "key": "3979576752e014adddadfc360d79c67cdccb0f4bae46936f35ce09c64e5832c8", - "log": "Train - CPU:173.81 % - Mem:0.11 GB - GPU:0.00 % - GPU Mem:0.00 GB; ", - "objective": { - "hash": "3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71", - "metrics": { - "hash": "c42dca31fbc2ebb5705643e3bb6ee666bbfd956de13dd03727f825ad8445b4d7", - "storageAddress": "http://testserver/objective/3d70ab46d710dacb0f48cb42db4874fac14e048a0d415e266aad38c09591ee71/metrics/" - } - }, - "outModel": { - "hash": "592242f9b162178994897c5b8aa49450a17cc395bb9bc9864b830a6cdba6a075", - "storageAddress": "http://testserver/model/592242f9b162178994897c5b8aa49450a17cc395bb9bc9864b830a6cdba6a075/file/" - }, - "permissions": "all", - "rank": 0, - "status": "done", - "tag": "substra" - } - } -] - diff --git a/substrabac/substrapp/tests/common.py b/substrabac/substrapp/tests/common.py deleted file mode 100644 index 2185d35fc..000000000 --- a/substrabac/substrapp/tests/common.py +++ /dev/null @@ -1,316 +0,0 @@ -from io import StringIO, BytesIO -import os - -from django.core.files.uploadedfile import InMemoryUploadedFile - -class gpu(): - """Fake gpu""" - - def __init__(self): - self.load = 0.8 - self.memoryUsed = 1024 - - -class Stats(): - @classmethod - def get_stats(cls): - """ Docker stats""" - return {"read": "2018-11-05T13:44:07.1782391Z", - "preread": "2018-11-05T13:44:06.1746531Z", - "pids_stats": { - "current": 8 - }, - "num_procs": 0, - "storage_stats": {}, - "cpu_stats": { - "cpu_usage": { - "total_usage": 22900421851, - "percpu_usage": [ - 4944824970, - 4964929089, - 8163433379, - 4827234413, - 0, - 0, - 0, - 0 - ], - "usage_in_kernelmode": 5520000000, - "usage_in_usermode": 17350000000 - }, - "system_cpu_usage": 185691120000000, - "online_cpus": 8, - "throttling_data": { - "periods": 0, - "throttled_periods": 0, - "throttled_time": 0 - }}, - "precpu_stats": { - "cpu_usage": { - "total_usage": 18898246805, - "percpu_usage": [ - 3938977859, - 3966955357, - 7165817747, - 3826495842, - 0, - 0, - 0, - 0 - ], - "usage_in_kernelmode": 5470000000, - "usage_in_usermode": 13390000000 - }, - "system_cpu_usage": 185683050000000, - "online_cpus": 8, - "throttling_data": { - "periods": 0, - "throttled_periods": 0, - "throttled_time": 0 - } - }, - "memory_stats": { - "usage": 1404354560, - "max_usage": 1404616704, - "limit": 8589934592 - }, - "name": "/job_c9868", - "id": "60fa7ab1c6dafdaa08ec3e2b95b16120757ac5cb7ebd512b3526b2d521623776", - "networks": { - "eth0": { - "rx_bytes": 758, - "rx_packets": 9, - "rx_errors": 0, - "rx_dropped": 0, - "tx_bytes": 0, - "tx_packets": 0, - "tx_errors": 0, - "tx_dropped": 0 - } - }} - - -class JobStats(): - - @classmethod - def get_new_stats(cls): - return {'memory': {'max': 0, - 'current': [0]}, - 'gpu_memory': {'max': 0, - 'current': [0]}, - 'cpu': {'max': 0, - 'current': [0]}, - 'gpu': {'max': 0, - 'current': []}, - 'io': {'max': 0, - 'current': []}, - 'netio': {'rx': 0, - 'tx': 0}, - 'time': 0} - - -def get_temporary_text_file(contents, filename): - """ - Creates a temporary text file - - :param contents: contents of the file - 
:param filename: name of the file - :type contents: str - :type filename: str - """ - f = StringIO() - flength = f.write(contents) - text_file = InMemoryUploadedFile(f, None, filename, 'text', flength, None) - # Setting the file to its start - text_file.seek(0) - return text_file - - -def get_sample_objective(): - description_content = "Super objective" - description_filename = "description.md" - description = get_temporary_text_file(description_content, description_filename) - metrics_content = "def metrics():\n\tpass" - metrics_filename = "metrics.py" - metrics = get_temporary_text_file(metrics_content, metrics_filename) - - return description, description_filename, metrics, metrics_filename - - -def get_sample_script(): - script_content = "import slidelib\n\ndef read():\n\tpass" - script_filename = "script.py" - script = get_temporary_text_file(script_content, script_filename) - - return script, script_filename - - -def get_sample_datamanager(): - description_content = "description" - description_filename = "description.md" - description = get_temporary_text_file(description_content, description_filename) - data_opener_content = "import slidelib\n\ndef read():\n\tpass" - data_opener_filename = "data_opener.py" - data_opener = get_temporary_text_file(data_opener_content, data_opener_filename) - - return description, description_filename, data_opener, data_opener_filename - - -def get_sample_datamanager2(): - description_content = "description 2" - description_filename = "description2.md" - description = get_temporary_text_file(description_content, description_filename) - data_opener_content = "import os\nimport slidelib\n\ndef read():\n\tpass" - data_opener_filename = "data_opener2.py" - data_opener = get_temporary_text_file(data_opener_content, data_opener_filename) - - return description, description_filename, data_opener, data_opener_filename - - -def get_sample_data_sample(): - file_content = "0\n1\n2" - file_filename = "file.csv" - file = get_temporary_text_file(file_content, file_filename) - - return file, file_filename - - -def get_sample_zip_data_sample(): - dir_path = os.path.dirname(os.path.realpath(__file__)) - file_filename = "file.zip" - f = BytesIO(b'foo') - with open(os.path.join(dir_path, '../../fixtures/owkin/datasamples/datasample4/0024900.zip'), 'rb') as zip_file: - flength = f.write(zip_file.read()) - - file = InMemoryUploadedFile(f, None, file_filename, - 'application/zip', flength, None) - file.seek(0) - - return file, file_filename - - -def get_sample_zip_data_sample_2(): - dir_path = os.path.dirname(os.path.realpath(__file__)) - file_filename = "file.zip" - f = BytesIO(b'foo') - with open(os.path.join(dir_path, '../../fixtures/owkin/datasamples/test/0024901.zip'), 'rb') as zip_file: - flength = f.write(zip_file.read()) - - file = InMemoryUploadedFile(f, None, file_filename, - 'application/zip', flength, None) - file.seek(0) - - return file, file_filename - - -def get_sample_tar_data_sample(): - dir_path = os.path.dirname(os.path.realpath(__file__)) - file_filename = "file.tar.gz" - f = BytesIO() - with open(os.path.join(dir_path, '../../fixtures/owkin/datasamples/datasample4/0024900.tar.gz'), 'rb') as tar_file: - flength = f.write(tar_file.read()) - - file = InMemoryUploadedFile(f, None, file_filename, - 'application/zip', flength, None) - file.seek(0) - - return file, file_filename - - -def get_sample_algo(): - dir_path = os.path.dirname(os.path.realpath(__file__)) - file_filename = "file.tar.gz" - f = BytesIO() - with open(os.path.join(dir_path, 
'../../fixtures/chunantes/algos/algo3/algo.tar.gz'), 'rb') as tar_file: - flength = f.write(tar_file.read()) - - file = InMemoryUploadedFile(f, None, file_filename, - 'application/tar+gzip', flength, None) - file.seek(0) - - return file, file_filename - - -def get_sample_model(): - model_content = "0.1, 0.2, -1.0" - model_filename = "model.bin" - model = get_temporary_text_file(model_content, model_filename) - - return model, model_filename - - -class FakeContainer(object): - def __init__(self): - self.c_stats = Stats.get_stats() - - def stats(self, decode, stream): - return self.c_stats - - -class FakeClient(object): - def __init__(self): - self.containers = {'job': FakeContainer()} - - -class FakeMetrics(object): - def __init__(self, filepath='path'): - self.path = filepath - - def save(self, p, f): - return - - -class FakeObjective(object): - def __init__(self, filepath='path'): - self.metrics = FakeMetrics(filepath) - - -class FakeOpener(object): - def __init__(self, filepath): - self.path = filepath - self.name = self.path - - -class FakeDataManager(object): - def __init__(self, filepath): - self.data_opener = FakeOpener(filepath) - - -class FakeFilterDataManager(object): - def __init__(self, count): - self.count_value = count - - def count(self): - return self.count_value - - -class FakePath(object): - def __init__(self, filepath): - self.path = filepath - - -class FakeModel(object): - def __init__(self, filepath): - self.file = FakePath(filepath) - - -class FakeAsyncResult(object): - def __init__(self, status=None, successful=True): - if status is not None: - self.status = status - self.success = successful - self.result = {'res': 'result'} - - def successful(self): - return self.success - - -class FakeRequest(object): - def __init__(self, status, content): - self.status_code = status - self.content = content - - -class FakeTask(object): - def __init__(self, task_id): - self.id = task_id diff --git a/substrabac/substrapp/tests/tests_misc.py b/substrabac/substrapp/tests/tests_misc.py deleted file mode 100644 index 8ee372c0d..000000000 --- a/substrabac/substrapp/tests/tests_misc.py +++ /dev/null @@ -1,76 +0,0 @@ -from django.test import TestCase - -from mock import patch -from substrapp.task_utils import get_cpu_sets, get_gpu_sets, ExceptionThread, \ - update_statistics - -from substrapp.tests.common import JobStats, Stats, gpu - -class MockDevice(): - """A mock device to temporarily suppress output to stdout - Similar to UNIX /dev/null. 
- """ - - def write(self, s): - pass - - -class MiscTests(TestCase): - """Misc tests""" - - def setUp(self): - pass - - def tearDown(self): - pass - - def test_cpu_sets(self): - cpu_count = 16 - for concurrency in range(1, cpu_count + 1, 1): - self.assertEqual(concurrency, - len(get_cpu_sets(cpu_count, concurrency))) - - def test_gpu_sets(self): - gpu_list = ['0', '1'] - for concurrency in range(1, len(gpu_list) + 1, 1): - self.assertEqual(concurrency, - len(get_gpu_sets(gpu_list, concurrency))) - - self.assertFalse(get_gpu_sets([], concurrency)) - - def test_exception_thread(self): - - training = ExceptionThread(target=lambda x, y: x / y, - args=(3, 0), - daemon=True) - - with patch('sys.stderr', new=MockDevice()): - training.start() - training.join() - - self.assertTrue(hasattr(training, '_exception')) - with self.assertRaises(ZeroDivisionError): - raise training._exception - - def test_update_statistics(self): - - # Statistics - - job_statistics = JobStats.get_new_stats() - tmp_statistics = JobStats.get_new_stats() - - update_statistics(job_statistics, None, None) - self.assertEqual(tmp_statistics, job_statistics) - - update_statistics(job_statistics, None, [gpu()]) - self.assertNotEqual(tmp_statistics, job_statistics) - self.assertEqual(job_statistics['gpu']['max'], 80) - self.assertEqual(job_statistics['gpu_memory']['max'], 1) - - job_statistics = JobStats.get_new_stats() - tmp_statistics = JobStats.get_new_stats() - update_statistics(job_statistics, Stats.get_stats(), None) - self.assertNotEqual(tmp_statistics, job_statistics) - self.assertNotEqual(job_statistics['memory']['max'], 0) - self.assertNotEqual(job_statistics['cpu']['max'], 0) - self.assertNotEqual(job_statistics['netio']['rx'], 0) diff --git a/substrabac/substrapp/tests/tests_query.py b/substrabac/substrapp/tests/tests_query.py deleted file mode 100644 index 0a604e695..000000000 --- a/substrabac/substrapp/tests/tests_query.py +++ /dev/null @@ -1,1190 +0,0 @@ -import os -import shutil -import tempfile -import zipfile -from unittest.mock import MagicMock - -import mock -from django.core.files import File -from django.core.files.uploadedfile import InMemoryUploadedFile - -from django.urls import reverse -from django.test import override_settings - -from rest_framework import status -from rest_framework.test import APITestCase - -from substrapp.models import Objective, DataManager, Algo, DataSample -from substrapp.serializers import LedgerObjectiveSerializer, \ - LedgerDataManagerSerializer, LedgerAlgoSerializer, \ - LedgerDataSampleSerializer, LedgerTrainTupleSerializer, DataSampleSerializer -from substrapp.utils import get_hash, compute_hash, get_dir_hash -from substrapp.views import DataSampleViewSet - -from .common import get_sample_objective, get_sample_datamanager, \ - get_sample_zip_data_sample, get_sample_script, \ - get_temporary_text_file, get_sample_datamanager2, get_sample_algo, \ - get_sample_tar_data_sample, get_sample_zip_data_sample_2 - -MEDIA_ROOT = tempfile.mkdtemp() - - -# APITestCase - -@override_settings(MEDIA_ROOT=MEDIA_ROOT) -class ObjectiveQueryTests(APITestCase): - - def setUp(self): - if not os.path.exists(MEDIA_ROOT): - os.makedirs(MEDIA_ROOT) - - self.objective_description, self.objective_description_filename, \ - self.objective_metrics, self.objective_metrics_filename = get_sample_objective() - - self.data_description, self.data_description_filename, self.data_data_opener, \ - self.data_opener_filename = get_sample_datamanager() - - def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except 
FileNotFoundError: - pass - - def add_default_data_manager(self): - DataManager.objects.create(name='slide opener', - description=self.data_description, - data_opener=self.data_data_opener) - - def get_default_objective_data(self): - # XXX reload fixtures as it is an opened buffer and a post will - # modify the objects - desc, _, metrics, _ = get_sample_objective() - - expected_hash = get_hash(self.objective_description) - data = { - 'name': 'tough objective', - 'test_data_manager_key': get_hash(self.data_data_opener), - 'test_data_sample_keys': [ - '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b379', - '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b389'], - 'description': desc, - 'metrics': metrics, - 'permissions': 'all', - 'metrics_name': 'accuracy' - } - return expected_hash, data - - def test_add_objective_sync_ok(self): - self.add_default_data_manager() - - pkhash, data = self.get_default_objective_data() - - url = reverse('substrapp:objective-list') - - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - with mock.patch.object(LedgerObjectiveSerializer, 'create') as mcreate: - mcreate.return_value = {'pkhash': pkhash}, status.HTTP_201_CREATED - - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - - self.assertEqual(r['pkhash'], pkhash) - self.assertEqual(r['validated'], False) - self.assertEqual(r['description'], - f'http://testserver/media/objectives/{r["pkhash"]}/{self.objective_description_filename}') - self.assertEqual(r['metrics'], - f'http://testserver/media/objectives/{r["pkhash"]}/{self.objective_metrics_filename}') - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - - def test_add_objective_conflict(self): - self.add_default_data_manager() - - pkhash, data = self.get_default_objective_data() - - url = reverse('substrapp:objective-list') - - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - with mock.patch.object(LedgerObjectiveSerializer, 'create') as mcreate: - mcreate.return_value = {'pkhash': pkhash}, status.HTTP_201_CREATED - - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - - self.assertEqual(r['pkhash'], pkhash) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - - # XXX reload data as the previous call to post change it - _, data = self.get_default_objective_data() - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - - self.assertEqual(response.status_code, status.HTTP_409_CONFLICT) - self.assertEqual(r['pkhash'], pkhash) - - def test_add_objective_no_sync_ok(self): - # add associated data opener - datamanager_name = 'slide opener' - DataManager.objects.create(name=datamanager_name, - description=self.data_description, - data_opener=self.data_data_opener) - - url = reverse('substrapp:objective-list') - - data = { - 'name': 'tough objective', - 'test_data_manager_key': get_hash(self.data_data_opener), - 'test_data_sample_keys': [ - '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b379', - '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b389'], - 'description': self.objective_description, - 'metrics': self.objective_metrics, - 'permissions': 'all', - 'metrics_name': 'accuracy' - } - - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - with mock.patch.object(LedgerObjectiveSerializer, 'create') as mcreate: - mcreate.return_value = {'message': 'Objective added in local db waiting for validation. 
\ - The substra network has been notified for adding this Objective'}, status.HTTP_202_ACCEPTED - response = self.client.post(url, data, format='multipart', **extra) - self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) - - def test_add_objective_ko(self): - url = reverse('substrapp:objective-list') - - data = {'name': 'empty objective'} - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - response = self.client.post(url, data, format='multipart', **extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - data = {'metrics': self.objective_metrics, - 'description': self.objective_description} - response = self.client.post(url, data, format='multipart', **extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_add_objective_no_version(self): - url = reverse('substrapp:objective-list') - - description_content = 'My Super top objective' - metrics_content = 'def metrics():\n\tpass' - - description = get_temporary_text_file(description_content, - 'description.md') - metrics = get_temporary_text_file(metrics_content, 'metrics.py') - - data = { - 'name': 'tough objective', - 'test_data_sample_keys': [ - 'data_5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b379', - 'data_5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b389'], - 'description': description, - 'metrics': metrics, - } - - response = self.client.post(url, data, format='multipart') - r = response.json() - - self.assertEqual(r, {'detail': 'A version is required.'}) - self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) - - def test_add_objective_wrong_version(self): - url = reverse('substrapp:objective-list') - - description_content = 'My Super top objective' - metrics_content = 'def metrics():\n\tpass' - - description = get_temporary_text_file(description_content, - 'description.md') - metrics = get_temporary_text_file(metrics_content, 'metrics.py') - - data = { - 'name': 'tough objective', - 'test_data_sample_keys': [ - 'data_5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b379', - 'data_5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b389'], - 'description': description, - 'metrics': metrics, - } - - extra = { - 'HTTP_ACCEPT': 'application/json;version=-1.0', - } - - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - - self.assertEqual(r, {'detail': 'Invalid version in "Accept" header.'}) - self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) - - def test_get_objective_metrics(self): - objective = Objective.objects.create( - description=self.objective_description, - metrics=self.objective_metrics) - with mock.patch( - 'substrapp.views.utils.getObjectFromLedger') as mgetObjectFromLedger: - mgetObjectFromLedger.return_value = self.objective_metrics - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - response = self.client.get( - f'/objective/{objective.pkhash}/metrics/', **extra) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertNotEqual(objective.pkhash, - compute_hash(response.getvalue())) - self.assertEqual(self.objective_metrics_filename, - response.filename) - # self.assertEqual(r, f'http://testserver/media/objectives/{objective.pkhash}/{self.objective_metrics_filename}') - - def test_get_objective_metrics_no_version(self): - objective = Objective.objects.create( - description=self.objective_description, - metrics=self.objective_metrics) - response = 
self.client.get(f'/objective/{objective.pkhash}/metrics/') - r = response.json() - self.assertEqual(r, {'detail': 'A version is required.'}) - self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) - - def test_get_objective_metrics_wrong_version(self): - objective = Objective.objects.create( - description=self.objective_description, - metrics=self.objective_metrics) - extra = { - 'HTTP_ACCEPT': 'application/json;version=-1.0', - } - response = self.client.get(f'/objective/{objective.pkhash}/metrics/', - **extra) - r = response.json() - self.assertEqual(r, {'detail': 'Invalid version in "Accept" header.'}) - self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) - - -@override_settings(MEDIA_ROOT=MEDIA_ROOT) -class DataManagerQueryTests(APITestCase): - - def setUp(self): - if not os.path.exists(MEDIA_ROOT): - os.makedirs(MEDIA_ROOT) - - self.data_description, self.data_description_filename, self.data_data_opener, \ - self.data_opener_filename = get_sample_datamanager() - - def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass - - def test_add_datamanager_sync_ok(self): - url = reverse('substrapp:data_manager-list') - - data = { - 'name': 'slide opener', - 'type': 'images', - 'permissions': 'all', - 'objective_key': '', - 'description': self.data_description, - 'data_opener': self.data_data_opener - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - with mock.patch.object(LedgerDataManagerSerializer, 'create') as mcreate: - mcreate.return_value = { - 'pkhash': 'da920c804c4724f1ce7bd0484edcf4aafa209d5bd54e2e89972c087a487cbe02'}, status.HTTP_201_CREATED - - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - - self.assertEqual(r['pkhash'], get_hash(self.data_data_opener)) - self.assertEqual(r['description'], - f'http://testserver/media/datamanagers/{r["pkhash"]}/{self.data_description_filename}') - - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - - def test_add_datamanager_no_sync_ok(self): - url = reverse('substrapp:data_manager-list') - data = { - 'name': 'slide opener', - 'type': 'images', - 'permissions': 'all', - 'objective_key': '', - 'description': self.data_description, - 'data_opener': self.data_data_opener - } - - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - with mock.patch.object(LedgerDataManagerSerializer, 'create') as mcreate: - mcreate.return_value = {'message': 'DataManager added in local db waiting for validation. 
\ - The substra network has been notified for adding this DataManager'}, status.HTTP_202_ACCEPTED - response = self.client.post(url, data, format='multipart', **extra) - - self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) - - def test_add_datamanager_ko(self): - url = reverse('substrapp:data_manager-list') - - data = {'name': 'toto'} - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - response = self.client.post(url, data, format='multipart', **extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_add_datamanager_no_version(self): - url = reverse('substrapp:data_manager-list') - - data = { - 'name': 'slide opener', - 'description': self.data_description, - 'data_opener': self.data_data_opener - } - response = self.client.post(url, data, format='multipart') - r = response.json() - - self.assertEqual(r, {'detail': 'A version is required.'}) - self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) - - def test_add_datamanager_wrong_version(self): - url = reverse('substrapp:data_manager-list') - - data = { - 'name': 'slide opener', - 'type': 'images', - 'permissions': 'all', - 'objective_key': '', - 'description': self.data_description, - 'data_opener': self.data_data_opener - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=-1.0', - } - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - - self.assertEqual(r, {'detail': 'Invalid version in "Accept" header.'}) - self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) - - -@override_settings(MEDIA_ROOT=MEDIA_ROOT) -class DataSampleQueryTests(APITestCase): - - def setUp(self): - if not os.path.exists(MEDIA_ROOT): - os.makedirs(MEDIA_ROOT) - - self.script, self.script_filename = get_sample_script() - self.data_file, self.data_file_filename = get_sample_zip_data_sample() - self.data_file_2, self.data_file_filename_2 = get_sample_zip_data_sample_2() - self.data_tar_file, self.data_tar_file_filename = get_sample_tar_data_sample() - - self.data_description, self.data_description_filename, self.data_data_opener, \ - self.data_opener_filename = get_sample_datamanager() - - self.data_description2, self.data_description_filename2, self.data_data_opener2, \ - self.data_opener_filename2 = get_sample_datamanager2() - - def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass - - def test_add_data_sample_sync_ok(self): - - # add associated data opener - datamanager_name = 'slide opener' - DataManager.objects.create(name=datamanager_name, - description=self.data_description, - data_opener=self.data_data_opener) - - url = reverse('substrapp:data_sample-list') - - data = { - 'file': self.data_file, - 'data_manager_keys': [get_hash(self.data_data_opener)], - 'test_only': True, - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - with mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: - mcreate.return_value = { - 'pkhash': '30f6c797e277451b0a08da7119ed86fb2986fa7fab2258bf3edbd9f1752ed553', - 'validated': True}, status.HTTP_201_CREATED - - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - self.data_file.file.seek(0) - self.assertEqual(r[0]['pkhash'], get_dir_hash(self.data_file.file)) - - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - - def test_bulk_add_data_sample_sync_ok(self): - - # add associated data opener - datamanager_name = 'slide opener' - 
DataManager.objects.create(name=datamanager_name, - description=self.data_description, - data_opener=self.data_data_opener) - - url = reverse('substrapp:data_sample-list') - - file_mock = MagicMock(spec=InMemoryUploadedFile) - file_mock2 = MagicMock(spec=InMemoryUploadedFile) - file_mock.name = 'foo.zip' - file_mock2.name = 'bar.zip' - file_mock.read = MagicMock(return_value=self.data_file.read()) - file_mock2.read = MagicMock(return_value=self.data_file_2.read()) - - data = { - file_mock.name: file_mock, - file_mock2.name: file_mock2, - 'data_manager_keys': [get_hash(self.data_data_opener)], - 'test_only': True, - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - with mock.patch('substrapp.serializers.datasample.DataSampleSerializer.get_validators') as mget_validators, \ - mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: - mget_validators.return_value = [] - self.data_file.seek(0) - self.data_file_2.seek(0) - ledger_data = {'pkhash': [get_dir_hash(file_mock), get_dir_hash(file_mock2)], 'validated': True} - mcreate.return_value = ledger_data, status.HTTP_201_CREATED - - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - self.assertEqual(len(r), 2) - self.assertEqual(r[0]['pkhash'], get_dir_hash(file_mock)) - self.assertTrue(r[0]['path'].endswith(f'/datasamples/{get_dir_hash(file_mock)}')) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - - def test_add_data_sample_no_sync_ok(self): - # add associated data opener - datamanager_name = 'slide opener' - DataManager.objects.create(name=datamanager_name, - description=self.data_description, - data_opener=self.data_data_opener) - url = reverse('substrapp:data_sample-list') - data = { - 'file': self.data_file, - 'data_manager_keys': [get_hash(self.data_data_opener)], - 'test_only': True, - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - with mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: - mcreate.return_value = {'message': 'Data added in local db waiting for validation. \ - The substra network has been notified for adding this Data'}, status.HTTP_202_ACCEPTED - response = self.client.post(url, data, format='multipart', **extra) - - self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) - - def test_add_data_sample_ko(self): - url = reverse('substrapp:data_sample-list') - - # missing datamanager - data = {'data_manager_keys': ['toto']} - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - self.assertEqual(r['message'], - "One or more datamanager keys provided do not exist in local substrabac database. Please create them before. 
DataManager keys: ['toto']") - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - datamanager_name = 'slide opener' - DataManager.objects.create(name=datamanager_name, - description=self.data_description, - data_opener=self.data_data_opener) - - # missing local storage field - data = {'data_manager_keys': [get_hash(self.data_description)], - 'test_only': True, } - response = self.client.post(url, data, format='multipart', **extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - # missing ledger field - data = {'data_manager_keys': [get_hash(self.data_description)], - 'file': self.script, } - response = self.client.post(url, data, format='multipart', **extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_add_data_sample_ko_already_exists(self): - url = reverse('substrapp:data_sample-list') - - datamanager_name = 'slide opener' - DataManager.objects.create(name=datamanager_name, - description=self.data_description, - data_opener=self.data_data_opener) - - file_mock = MagicMock(spec=InMemoryUploadedFile) - file_mock.name = 'foo.zip' - file_mock.read = MagicMock(return_value=self.data_file.file.read()) - file_mock.open = MagicMock(return_value=file_mock) - - d = DataSample(path=file_mock) - # trigger pre save - d.save() - - data = { - 'file': file_mock, - 'data_manager_keys': [get_hash(self.data_data_opener)], - 'test_only': True, - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - with mock.patch.object(zipfile, 'is_zipfile') as mis_zipfile: - mis_zipfile.return_value = True - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - self.assertEqual(r['message'], - [[{'pkhash': ['data sample with this pkhash already exists.']}]]) - self.assertEqual(response.status_code, status.HTTP_409_CONFLICT) - - def test_add_data_sample_ko_not_a_zip(self): - url = reverse('substrapp:data_sample-list') - - datamanager_name = 'slide opener' - DataManager.objects.create(name=datamanager_name, - description=self.data_description, - data_opener=self.data_data_opener) - - file_mock = MagicMock(spec=File) - file_mock.name = 'foo.zip' - file_mock.read = MagicMock(return_value=b'foo') - - data = { - 'file': file_mock, - 'data_manager_keys': [get_hash(self.data_data_opener)], - 'test_only': True, - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - self.assertEqual(r['message'], 'Archive must be zip or tar.*') - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_add_data_sample_ko_408(self): - url = reverse('substrapp:data_sample-list') - - datamanager_name = 'slide opener' - DataManager.objects.create(name=datamanager_name, - description=self.data_description, - data_opener=self.data_data_opener) - - file_mock = MagicMock(spec=InMemoryUploadedFile) - file_mock.name = 'foo.zip' - file_mock.read = MagicMock(return_value=self.data_file.file.read()) - file_mock.open = MagicMock(return_value=file_mock) - - data = { - 'file': file_mock, - 'data_manager_keys': [get_hash(self.data_data_opener)], - 'test_only': True, - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - with mock.patch.object(zipfile, 'is_zipfile') as mis_zipfile, \ - mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: - mcreate.return_value = {'pkhash': get_hash(file_mock), 'validated': False}, status.HTTP_408_REQUEST_TIMEOUT - 
mis_zipfile.return_value = True - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - self.assertEqual(r['message'], {'pkhash': [get_dir_hash(file_mock)], 'validated': False}) - self.assertEqual(response.status_code, status.HTTP_408_REQUEST_TIMEOUT) - - def test_bulk_add_data_sample_ko_408(self): - - # add associated data opener - datamanager_name = 'slide opener' - DataManager.objects.create(name=datamanager_name, - description=self.data_description, - data_opener=self.data_data_opener) - - url = reverse('substrapp:data_sample-list') - - file_mock = MagicMock(spec=InMemoryUploadedFile) - file_mock2 = MagicMock(spec=InMemoryUploadedFile) - file_mock.name = 'foo.zip' - file_mock2.name = 'bar.zip' - file_mock.read = MagicMock(return_value=self.data_file.read()) - file_mock2.read = MagicMock(return_value=self.data_file_2.read()) - - data = { - file_mock.name: file_mock, - file_mock2.name: file_mock2, - 'data_manager_keys': [get_hash(self.data_data_opener)], - 'test_only': True, - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - with mock.patch('substrapp.serializers.datasample.DataSampleSerializer.get_validators') as mget_validators, \ - mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: - mget_validators.return_value = [] - self.data_file.seek(0) - self.data_tar_file.seek(0) - ledger_data = {'pkhash': [get_dir_hash(file_mock), get_dir_hash(file_mock2)], 'validated': False} - mcreate.return_value = ledger_data, status.HTTP_408_REQUEST_TIMEOUT - - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - self.assertEqual(r['message']['validated'], False) - self.assertEqual(DataSample.objects.count(), 2) - self.assertEqual(response.status_code, status.HTTP_408_REQUEST_TIMEOUT) - - def test_bulk_add_data_sample_ko_same_pkhash(self): - - # add associated data opener - datamanager_name = 'slide opener' - DataManager.objects.create(name=datamanager_name, - description=self.data_description, - data_opener=self.data_data_opener) - - url = reverse('substrapp:data_sample-list') - - file_mock = MagicMock(spec=InMemoryUploadedFile) - file_mock2 = MagicMock(spec=InMemoryUploadedFile) - file_mock.name = 'foo.zip' - file_mock2.name = 'bar.tar.gz' - file_mock.read = MagicMock(return_value=self.data_file.read()) - file_mock2.read = MagicMock(return_value=self.data_tar_file.read()) - - data = { - file_mock.name: file_mock, - file_mock2.name: file_mock2, - 'data_manager_keys': [get_hash(self.data_data_opener)], - 'test_only': True, - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - with mock.patch('substrapp.serializers.datasample.DataSampleSerializer.get_validators') as mget_validators, \ - mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: - mget_validators.return_value = [] - self.data_file.seek(0) - self.data_tar_file.seek(0) - ledger_data = {'pkhash': [get_dir_hash(file_mock), get_dir_hash(file_mock2)], 'validated': False} - mcreate.return_value = ledger_data, status.HTTP_408_REQUEST_TIMEOUT - - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - self.assertEqual(DataSample.objects.count(), 0) - self.assertEqual(r['message'], f'Your data sample archives contain same files leading to same pkhash, please review the content of your achives. 
Archives {file_mock2.name} and {file_mock.name} are the same') - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_add_data_sample_ko_400(self): - url = reverse('substrapp:data_sample-list') - - datamanager_name = 'slide opener' - DataManager.objects.create(name=datamanager_name, - description=self.data_description, - data_opener=self.data_data_opener) - - file_mock = MagicMock(spec=InMemoryUploadedFile) - file_mock.name = 'foo.zip' - file_mock.read = MagicMock(return_value=self.data_file.file.read()) - - data = { - 'file': file_mock, - 'data_manager_keys': [get_hash(self.data_data_opener)], - 'test_only': True, - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - with mock.patch.object(zipfile, 'is_zipfile') as mis_zipfile, \ - mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: - mcreate.return_value = 'Failed', status.HTTP_400_BAD_REQUEST - mis_zipfile.return_value = True - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - self.assertEqual(r['message'], 'Failed') - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_add_data_sample_ko_serializer_invalid(self): - url = reverse('substrapp:data_sample-list') - - datamanager_name = 'slide opener' - DataManager.objects.create(name=datamanager_name, - description=self.data_description, - data_opener=self.data_data_opener) - - file_mock = MagicMock(spec=InMemoryUploadedFile) - file_mock.name = 'foo.zip' - file_mock.read = MagicMock(return_value=self.data_file.read()) - - data = { - 'file': file_mock, - 'data_manager_keys': [get_hash(self.data_data_opener)], - 'test_only': True, - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - with mock.patch.object(zipfile, 'is_zipfile') as mis_zipfile, \ - mock.patch.object(DataSampleViewSet, 'get_serializer') as mget_serializer: - mocked_serializer = MagicMock(DataSampleSerializer) - mocked_serializer.is_valid.return_value = True - mocked_serializer.save.side_effect = Exception('Failed') - mget_serializer.return_value = mocked_serializer - - mis_zipfile.return_value = True - - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - self.assertEqual(r['message'], "Failed") - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_add_data_sample_ko_ledger_invalid(self): - url = reverse('substrapp:data_sample-list') - - datamanager_name = 'slide opener' - DataManager.objects.create(name=datamanager_name, - description=self.data_description, - data_opener=self.data_data_opener) - - file_mock = MagicMock(spec=InMemoryUploadedFile) - file_mock.name = 'foo.zip' - file_mock.read = MagicMock(return_value=self.data_file.file.read()) - - data = { - 'file': file_mock, - 'data_manager_keys': [get_hash(self.data_data_opener)], - 'test_only': True, - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - with mock.patch.object(zipfile, 'is_zipfile') as mis_zipfile, \ - mock.patch('substrapp.views.datasample.LedgerDataSampleSerializer', spec=True) as mLedgerDataSampleSerializer: - mocked_LedgerDataSampleSerializer = MagicMock() - mocked_LedgerDataSampleSerializer.is_valid.return_value = False - mocked_LedgerDataSampleSerializer.errors = 'Failed' - mLedgerDataSampleSerializer.return_value = mocked_LedgerDataSampleSerializer - - mis_zipfile.return_value = True - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - self.assertEqual(r['message'], 
"[ErrorDetail(string='Failed', code='invalid')]") - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_add_data_sample_no_version(self): - - # add associated data opener - datamanager_name = 'slide opener' - DataManager.objects.create(name=datamanager_name, - description=self.data_description, - data_opener=self.data_data_opener) - - url = reverse('substrapp:data_sample-list') - - data = { - 'file': self.data_file, - 'data_manager_keys': [get_hash(self.data_description)], - 'test_only': True, - } - response = self.client.post(url, data, format='multipart') - r = response.json() - - self.assertEqual(r, {'detail': 'A version is required.'}) - self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) - - def test_add_data_sample_wrong_version(self): - - # add associated data opener - datamanager_name = 'slide opener' - DataManager.objects.create(name=datamanager_name, - description=self.data_description, - data_opener=self.data_data_opener) - - url = reverse('substrapp:data_sample-list') - - data = { - 'file': self.script, - 'data_manager_keys': [datamanager_name], - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=-1.0', - } - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - - self.assertEqual(r, {'detail': 'Invalid version in "Accept" header.'}) - self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) - - def test_bulk_update_data(self): - - # add associated data opener - datamanager = DataManager.objects.create(name='slide opener', - description=self.data_description, - data_opener=self.data_data_opener) - datamanager2 = DataManager.objects.create(name='slide opener 2', - description=self.data_description2, - data_opener=self.data_data_opener2) - - d = DataSample(path=self.data_file) - # trigger pre save - d.save() - - url = reverse('substrapp:data_sample-bulk-update') - - data = { - 'data_manager_keys': [datamanager.pkhash, datamanager2.pkhash], - 'data_keys': [d.pkhash], - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - with mock.patch( - 'substrapp.serializers.ledger.datasample.util.invokeLedger') as minvokeLedger: - minvokeLedger.return_value = {'keys': [ - d.pkhash]}, status.HTTP_200_OK - - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - self.assertEqual(r['keys'], [d.pkhash]) - self.assertEqual(response.status_code, status.HTTP_200_OK) - - -@override_settings(MEDIA_ROOT=MEDIA_ROOT) -class AlgoQueryTests(APITestCase): - - def setUp(self): - if not os.path.exists(MEDIA_ROOT): - os.makedirs(MEDIA_ROOT) - - self.objective_description, self.objective_description_filename, \ - self.objective_metrics, self.objective_metrics_filename = get_sample_objective() - - self.algo, self.algo_filename = get_sample_algo() - - self.data_description, self.data_description_filename, self.data_data_opener, \ - self.data_opener_filename = get_sample_datamanager() - - def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass - - def test_add_algo_sync_ok(self): - - dir_path = os.path.dirname(os.path.realpath(__file__)) - with open(os.path.join(dir_path, '../../fixtures/chunantes/algos/algo3/algo.tar.gz'), 'rb') as tar_file: - algo_content = tar_file.read() - - # add associated objective - Objective.objects.create(description=self.objective_description, - metrics=self.objective_metrics) - - url = reverse('substrapp:algo-list') - - data = { - 'file': self.algo, - 'description': self.data_description, - 
'name': 'super top algo', - 'objective_key': get_hash(self.objective_description), - 'permissions': 'all' - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - with mock.patch.object(LedgerAlgoSerializer, 'create') as mcreate: - mcreate.return_value = {'pkhash': compute_hash(algo_content)}, status.HTTP_201_CREATED - - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - - self.assertEqual(r['pkhash'], compute_hash(algo_content)) - - def test_add_algo_no_sync_ok(self): - # add associated objective - Objective.objects.create(description=self.objective_description, - metrics=self.objective_metrics) - url = reverse('substrapp:algo-list') - data = { - 'file': self.algo, - 'description': self.data_description, - 'name': 'super top algo', - 'objective_key': get_hash(self.objective_description), - 'permissions': 'all' - } - - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - with mock.patch.object(LedgerAlgoSerializer, 'create') as mcreate: - mcreate.return_value = {'message': 'Algo added in local db waiting for validation. \ - The substra network has been notified for adding this Algo'}, status.HTTP_202_ACCEPTED - response = self.client.post(url, data, format='multipart', **extra) - - self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) - - def test_add_algo_ko(self): - url = reverse('substrapp:algo-list') - - # non existing associated objective - data = { - 'file': self.algo, - 'description': self.data_description, - 'name': 'super top algo', - 'objective_key': 'non existing objectivexxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', - 'permissions': 'all' - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - with mock.patch.object(LedgerAlgoSerializer, 'create') as mcreate: - mcreate.return_value = { - 'message': 'Fail to add algo. 
Objective does not exist'}, status.HTTP_400_BAD_REQUEST - - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - self.assertIn('does not exist', r['message']) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - Objective.objects.create(description=self.objective_description, - metrics=self.objective_metrics) - - # missing local storage field - data = { - 'name': 'super top algo', - 'objective_key': get_hash(self.objective_description), - 'permissions': 'all' - } - response = self.client.post(url, data, format='multipart', **extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - # missing ledger field - data = { - 'file': self.algo, - 'description': self.data_description, - 'objective_key': get_hash(self.objective_description), - } - response = self.client.post(url, data, format='multipart', **extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_add_algo_no_version(self): - - # add associated objective - Objective.objects.create(description=self.objective_description, - metrics=self.objective_metrics) - - url = reverse('substrapp:algo-list') - - data = { - 'file': self.algo, - 'description': self.data_description, - 'name': 'super top algo', - 'objective_key': get_hash(self.objective_description), - 'permissions': 'all' - } - response = self.client.post(url, data, format='multipart') - r = response.json() - - self.assertEqual(r, {'detail': 'A version is required.'}) - self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) - - def test_add_algo_wrong_version(self): - - # add associated objective - Objective.objects.create(description=self.objective_description, - metrics=self.objective_metrics) - - url = reverse('substrapp:algo-list') - - data = { - 'file': self.algo, - 'description': self.data_description, - 'name': 'super top algo', - 'objective_key': get_hash(self.objective_description), - 'permissions': 'all' - } - extra = { - 'HTTP_ACCEPT': 'application/json;version=-1.0', - } - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - - self.assertEqual(r, {'detail': 'Invalid version in "Accept" header.'}) - self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) - - def test_get_algo_files(self): - algo = Algo.objects.create(file=self.algo) - with mock.patch( - 'substrapp.views.utils.getObjectFromLedger') as mgetObjectFromLedger: - mgetObjectFromLedger.return_value = self.algo - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - response = self.client.get(f'/algo/{algo.pkhash}/file/', **extra) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(algo.pkhash, compute_hash(response.getvalue())) - # self.assertEqual(r, f'http://testserver/media/algos/{algo.pkhash}/{self.algo_filename}') - - def test_get_algo_files_no_version(self): - algo = Algo.objects.create(file=self.algo) - response = self.client.get(f'/algo/{algo.pkhash}/file/') - r = response.json() - self.assertEqual(r, {'detail': 'A version is required.'}) - self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) - - def test_get_algo_files_wrong_version(self): - algo = Algo.objects.create(file=self.algo) - extra = { - 'HTTP_ACCEPT': 'application/json;version=-1.0', - } - response = self.client.get(f'/algo/{algo.pkhash}/file/', **extra) - r = response.json() - self.assertEqual(r, {'detail': 'Invalid version in "Accept" header.'}) - self.assertEqual(response.status_code, 
status.HTTP_406_NOT_ACCEPTABLE) - - -@override_settings(MEDIA_ROOT=MEDIA_ROOT) -class TraintupleQueryTests(APITestCase): - - def setUp(self): - if not os.path.exists(MEDIA_ROOT): - os.makedirs(MEDIA_ROOT) - - self.objective_description, self.objective_description_filename, \ - self.objective_metrics, self.objective_metrics_filename = get_sample_objective() - - def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass - - def test_add_traintuple_ok(self): - # Add associated objective - description, _, metrics, _ = get_sample_objective() - Objective.objects.create(description=description, - metrics=metrics) - # post data - url = reverse('substrapp:traintuple-list') - - data = {'train_data_sample_keys': [ - '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b422'], - 'algo_key': '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0a088', - 'data_manager_key': '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0a088', - 'objective_key': '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0a088', - 'rank': -1, - 'FLtask_key': '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0a088', - 'in_models_keys': [ - '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b422']} - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - with mock.patch.object(LedgerTrainTupleSerializer, 'create') as mcreate, \ - mock.patch('substrapp.views.traintuple.queryLedger') as mqueryLedger: - - raw_pkhash = 'traintuple_pkhash'.encode('utf-8').hex() - mqueryLedger.return_value = ({'key': raw_pkhash}, status.HTTP_200_OK) - mcreate.return_value = {'message': 'Traintuple added in local db waiting for validation. \ - The substra network has been notified for adding this Traintuple'}, status.HTTP_202_ACCEPTED - - response = self.client.post(url, data, format='multipart', **extra) - - print(response.json()) - self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) - - def test_add_traintuple_ko(self): - url = reverse('substrapp:traintuple-list') - - data = {'train_data_sample_keys': [ - '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b422'], - 'model_key': '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0a088'} - - extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0', - } - - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - self.assertIn('This field may not be null.', r['algo_key']) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - Objective.objects.create(description=self.objective_description, - metrics=self.objective_metrics) - data = {'objective': get_hash(self.objective_description)} - response = self.client.post(url, data, format='multipart', **extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_add_traintuple_no_version(self): - # Add associated objective - description, _, metrics, _ = get_sample_objective() - Objective.objects.create(description=description, - metrics=metrics) - # post data - url = reverse('substrapp:traintuple-list') - - data = {'train_data_sample_keys': [ - '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b422'], - 'datamanager_key': '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0a088', - 'model_key': '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0a088', - 'algo_key': '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0a088'} - - response = self.client.post(url, data, format='multipart') - r = response.json() - 
self.assertEqual(r, {'detail': 'A version is required.'}) - self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) - - def test_add_traintuple_wrong_version(self): - # Add associated objective - description, _, metrics, _ = get_sample_objective() - Objective.objects.create(description=description, - metrics=metrics) - # post data - url = reverse('substrapp:traintuple-list') - - data = {'train_data_sample_keys': [ - '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b422'], - 'datamanager_key': '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0a088', - 'model_key': '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0a088', - 'algo_key': '5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0a088'} - extra = { - 'HTTP_ACCEPT': 'application/json;version=-1.0', - } - - response = self.client.post(url, data, format='multipart', **extra) - r = response.json() - self.assertEqual(r, {'detail': 'Invalid version in "Accept" header.'}) - self.assertEqual(response.status_code, status.HTTP_406_NOT_ACCEPTABLE) diff --git a/substrabac/substrapp/tests/tests_tasks.py b/substrabac/substrapp/tests/tests_tasks.py deleted file mode 100644 index c289be5ea..000000000 --- a/substrabac/substrapp/tests/tests_tasks.py +++ /dev/null @@ -1,538 +0,0 @@ -import os -import shutil -import mock -import time -import uuid - -from django.test import override_settings -from django.http import HttpResponse -from rest_framework import status -from rest_framework.test import APITestCase - -from substrapp.models import DataSample -from substrapp.utils import compute_hash, get_computed_hash, get_remote_file, get_hash, create_directory -from substrapp.task_utils import ResourcesManager, monitoring_task, compute_docker, ExceptionThread -from substrapp.tasks import build_subtuple_folders, get_algo, get_model, get_models, get_objective, put_opener, put_model, put_models, put_algo, put_metric, put_data_sample, prepareTask, doTask, computeTask - -from .common import get_sample_algo, get_sample_script, get_sample_zip_data_sample, get_sample_tar_data_sample, get_sample_model -from .common import FakeClient, FakeObjective, FakeDataManager, FakeModel - -import zipfile -import docker -MEDIA_ROOT = "/tmp/unittests_tasks/" -# MEDIA_ROOT = tempfile.mkdtemp() - - -# APITestCase -@override_settings(MEDIA_ROOT=MEDIA_ROOT) -class TasksTests(APITestCase): - - def setUp(self): - if not os.path.exists(MEDIA_ROOT): - os.makedirs(MEDIA_ROOT) - - self.subtuple_path = MEDIA_ROOT - - self.script, self.script_filename = get_sample_script() - - self.algo, self.algo_filename = get_sample_algo() - self.data_sample, self.data_sample_filename = get_sample_zip_data_sample() - self.data_sample_tar, self.data_sample_tar_filename = get_sample_tar_data_sample() - self.model, self.model_filename = get_sample_model() - - self.ResourcesManager = ResourcesManager() - - def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass - - def test_create_directory(self): - directory = './test/' - create_directory(directory) - self.assertTrue(os.path.exists(directory)) - shutil.rmtree(directory) - - def test_get_computed_hash(self): - with mock.patch('substrapp.utils.requests.get') as mget: - mget.return_value = HttpResponse(str(self.script.read())) - _, pkhash = get_computed_hash('localhost') - self.assertEqual(pkhash, get_hash(self.script)) - - with mock.patch('substrapp.utils.requests.get') as mget: - mget.return_value = HttpResponse() - mget.return_value.status_code = status.HTTP_400_BAD_REQUEST 
- with self.assertRaises(Exception): - get_computed_hash('localhost') - - def test_get_remote_file(self): - content = str(self.script.read()) - remote_file = {'storageAddress': 'localhost', - 'hash': compute_hash(content)} - - with mock.patch('substrapp.utils.get_computed_hash') as mget_computed_hash: - pkhash = compute_hash(content) - mget_computed_hash.return_value = content, pkhash - - content_remote, pkhash_remote = get_remote_file(remote_file) - self.assertEqual(pkhash_remote, get_hash(self.script)) - self.assertEqual(content_remote, content) - - with mock.patch('substrapp.utils.get_computed_hash') as mget_computed_hash: - content = content + ' FAIL' - pkhash = compute_hash(content) - mget_computed_hash.return_value = content, pkhash - - with self.assertRaises(Exception): - get_remote_file(remote_file) # contents (by pkhash) are different - - def test_Ressource_Manager(self): - - self.assertTrue(isinstance(self.ResourcesManager.memory_limit_mb(), int)) - - cpu_set, gpu_set = self.ResourcesManager.get_cpu_gpu_sets() - self.assertIn(cpu_set, self.ResourcesManager._ResourcesManager__cpu_sets) - - if gpu_set is not None: - self.assertIn(gpu_set, self.ResourcesManager._ResourcesManager__gpu_sets) - - def test_monitoring_task(self): - - monitoring = ExceptionThread(target=monitoring_task, args=(FakeClient(), {'name': 'job'})) - monitoring.start() - time.sleep(0.1) - monitoring.join() - - self.assertNotEqual(monitoring._stats['memory']['max'], 0) - self.assertNotEqual(monitoring._stats['cpu']['max'], 0) - self.assertNotEqual(monitoring._stats['netio']['rx'], 0) - - def test_put_algo_tar(self): - algo_content = self.algo.read() - subtuple_key = get_hash(self.algo) - - subtuple = {'key': subtuple_key, - 'algo': 'testalgo'} - - with mock.patch('substrapp.tasks.get_hash') as mget_hash: - mget_hash.return_value = subtuple_key - put_algo(os.path.join(self.subtuple_path, f'subtuple/{subtuple["key"]}/'), algo_content) - - def tree_printer(root): - for root, dirs, files in os.walk(root): - for d in dirs: - print(os.path.join(root, d)) - for f in files: - print(os.path.join(root, f)) - - self.assertTrue(os.path.exists(os.path.join(self.subtuple_path, f'subtuple/{subtuple["key"]}/algo.py'))) - self.assertTrue(os.path.exists(os.path.join(self.subtuple_path, f'subtuple/{subtuple["key"]}/Dockerfile'))) - - def test_put_algo_zip(self): - filename = 'algo.py' - filepath = os.path.join(self.subtuple_path, filename) - with open(filepath, 'w') as f: - f.write('Hello World') - self.assertTrue(os.path.exists(filepath)) - - zipname = 'sample.zip' - zippath = os.path.join(self.subtuple_path, zipname) - with zipfile.ZipFile(zippath, mode='w') as zf: - zf.write(filepath, arcname=filename) - self.assertTrue(os.path.exists(zippath)) - - subtuple_key = 'testkey' - subtuple = {'key': subtuple_key, 'algo': 'testalgo'} - - with mock.patch('substrapp.tasks.get_hash') as mget_hash: - with open(zippath, 'rb') as content: - mget_hash.return_value = get_hash(zippath) - put_algo(os.path.join(self.subtuple_path, f'subtuple/{subtuple["key"]}/'), content.read()) - - self.assertTrue(os.path.exists(os.path.join(self.subtuple_path, f'subtuple/{subtuple["key"]}/{filename}'))) - - def test_put_metric(self): - - filepath = os.path.join(self.subtuple_path, self.script_filename) - with open(filepath, 'w') as f: - f.write(self.script.read()) - self.assertTrue(os.path.exists(filepath)) - - metrics_directory = os.path.join(self.subtuple_path, 'metrics/') - create_directory(metrics_directory) - - put_metric(self.subtuple_path, 
FakeObjective(filepath)) - self.assertTrue(os.path.exists(os.path.join(metrics_directory, 'metrics.py'))) - - def test_put_opener(self): - - filepath = os.path.join(self.subtuple_path, self.script_filename) - with open(filepath, 'w') as f: - f.write(self.script.read()) - self.assertTrue(os.path.exists(filepath)) - - opener_directory = os.path.join(self.subtuple_path, 'opener') - create_directory(opener_directory) - - with mock.patch('substrapp.models.DataManager.objects.get') as mget: - mget.return_value = FakeDataManager(filepath) - - # test fail - with self.assertRaises(Exception): - put_opener({'dataset': {'openerHash': 'HASH'}}, self.subtuple_path) - - # test work - put_opener({'dataset': {'openerHash': get_hash(filepath)}}, self.subtuple_path) - - self.assertTrue(os.path.exists(os.path.join(opener_directory, 'opener.py'))) - - def test_put_data_sample_zip(self): - - data_sample = DataSample(pkhash='foo', path=self.data_sample) - data_sample.save() - - subtuple = { - 'key': 'bar', - 'dataset': {'keys': [data_sample.pk]} - } - - with mock.patch('substrapp.models.DataSample.objects.get') as mget: - mget.return_value = data_sample - - subtuple_direcory = build_subtuple_folders(subtuple) - - put_data_sample(subtuple, subtuple_direcory) - - # check folder has been correctly renamed with pk of directory containing uncompressed data sample - self.assertFalse( - os.path.exists(os.path.join(MEDIA_ROOT, 'datasamples', 'foo'))) - dir_pkhash = '30f6c797e277451b0a08da7119ed86fb2986fa7fab2258bf3edbd9f1752ed553' - self.assertTrue( - os.path.exists(os.path.join(MEDIA_ROOT, 'datasamples', dir_pkhash))) - - # check subtuple folder has been created and sym links exists - self.assertTrue(os.path.exists(os.path.join(MEDIA_ROOT, 'subtuple/bar/data', data_sample.pk))) - self.assertTrue(os.path.islink(os.path.join(MEDIA_ROOT, 'subtuple/bar/data', data_sample.pk))) - self.assertTrue(os.path.exists(os.path.join(MEDIA_ROOT, 'subtuple/bar/data', data_sample.pk, 'LABEL_0024900.csv'))) - self.assertTrue(os.path.exists(os.path.join(MEDIA_ROOT, 'subtuple/bar/data', data_sample.pk, 'IMG_0024900.jpg'))) - - def test_put_data_tar(self): - - data_sample = DataSample(pkhash='foo', path=self.data_sample_tar) - data_sample.save() - - subtuple = { - 'key': 'bar', - 'dataset': {'keys': [data_sample.pk]} - } - - with mock.patch('substrapp.models.DataSample.objects.get') as mget: - mget.return_value = data_sample - - subtuple_direcory = build_subtuple_folders(subtuple) - - put_data_sample(subtuple, subtuple_direcory) - - # check folder has been correctly renamed with pk of directory containing uncompressed data_sample - self.assertFalse(os.path.exists(os.path.join(MEDIA_ROOT, 'datasamples', 'foo'))) - dir_pkhash = '30f6c797e277451b0a08da7119ed86fb2986fa7fab2258bf3edbd9f1752ed553' - self.assertTrue(os.path.exists(os.path.join(MEDIA_ROOT, 'datasamples', dir_pkhash))) - - # check subtuple folder has been created and sym links exists - self.assertTrue(os.path.exists(os.path.join(MEDIA_ROOT, 'subtuple/bar/data', data_sample.pk))) - self.assertTrue(os.path.islink(os.path.join(MEDIA_ROOT, 'subtuple/bar/data', data_sample.pk))) - self.assertTrue(os.path.exists(os.path.join(MEDIA_ROOT, 'subtuple/bar/data', data_sample.pk, 'LABEL_0024900.csv'))) - self.assertTrue(os.path.exists(os.path.join(MEDIA_ROOT, 'subtuple/bar/data', data_sample.pk, 'IMG_0024900.jpg'))) - - def test_put_model(self): - - model_content = self.model.read().encode() - - traintupleKey = compute_hash(model_content) - model_hash = compute_hash(model_content, 
traintupleKey) - model_type = 'model' - subtuple = {'key': model_hash, model_type: {'hash': model_hash, 'traintupleKey': traintupleKey}} - - model_directory = os.path.join(self.subtuple_path, 'model') - create_directory(model_directory) - put_model(subtuple, self.subtuple_path, model_content) - - model_path = os.path.join(model_directory, traintupleKey) - self.assertTrue(os.path.exists(model_path)) - - os.rename(model_path, model_path + '-local') - with mock.patch('substrapp.models.Model.objects.get') as mget: - mget.return_value = FakeModel(model_path + '-local') - put_model(subtuple, self.subtuple_path, model_content) - self.assertTrue(os.path.exists(model_path)) - - with mock.patch('substrapp.models.Model.objects.get') as mget: - mget.return_value = FakeModel(model_path) - with self.assertRaises(Exception): - put_model({'model': {'hash': 'fail-hash'}}, self.subtuple_path, model_content) - - def test_put_models(self): - - model_content = self.model.read().encode() - models_content = [model_content, model_content + b', -2.0'] - - traintupleKey = compute_hash(models_content[0]) - model_hash = compute_hash(models_content[0], traintupleKey) - - traintupleKey2 = compute_hash(models_content[1]) - model_hash2 = compute_hash(models_content[1], traintupleKey2) - - model_path = os.path.join(self.subtuple_path, 'model', traintupleKey) - model_path2 = os.path.join(self.subtuple_path, 'model', traintupleKey2) - - model_type = 'inModels' - subtuple = {model_type: [{'hash': model_hash, 'traintupleKey': traintupleKey}, - {'hash': model_hash2, 'traintupleKey': traintupleKey2}]} - - model_directory = os.path.join(self.subtuple_path, 'model/') - - create_directory(model_directory) - put_models(subtuple, self.subtuple_path, models_content) - - self.assertTrue(os.path.exists(model_path)) - self.assertTrue(os.path.exists(model_path2)) - - os.rename(model_path, model_path + '-local') - os.rename(model_path2, model_path2 + '-local') - - with mock.patch('substrapp.models.Model.objects.get') as mget: - mget.side_effect = [FakeModel(model_path + '-local'), FakeModel(model_path2 + '-local')] - put_models(subtuple, self.subtuple_path, models_content) - - self.assertTrue(os.path.exists(model_path)) - self.assertTrue(os.path.exists(model_path2)) - - with mock.patch('substrapp.models.Model.objects.get') as mget: - mget.return_value = FakeModel(model_path) - with self.assertRaises(Exception): - put_models({'inModels': [{'hash': 'hash'}]}, self.subtuple_path, model_content) - - def test_get_model(self): - model_content = self.model.read().encode() - traintupleKey = compute_hash(model_content) - model_hash = compute_hash(model_content, traintupleKey) - model_type = 'model' - subtuple = {model_type: {'hash': model_hash, 'traintupleKey': traintupleKey}} - - with mock.patch('substrapp.tasks.get_remote_file') as mget_remote_file: - mget_remote_file.return_value = model_content, model_hash - model_content, model_hash = get_model(subtuple) - - self.assertIsNotNone(model_content) - self.assertIsNotNone(model_hash) - - def test_get_models(self): - model_content = self.model.read().encode() - models_content = [model_content, model_content + b', -2.0'] - - traintupleKey = compute_hash(models_content[0]) - model_hash = compute_hash(models_content[0], traintupleKey) - - traintupleKey2 = compute_hash(models_content[1]) - model_hash2 = compute_hash(models_content[1], traintupleKey2) - - models_hash = [model_hash, model_hash2] - model_type = 'inModels' - subtuple = {model_type: [{'hash': model_hash, 'traintupleKey': traintupleKey}, - 
{'hash': model_hash2, 'traintupleKey': traintupleKey2}]} - - with mock.patch('substrapp.tasks.get_remote_file') as mget_remote_file: - mget_remote_file.side_effect = [[models_content[0], models_hash[0]], - [models_content[1], models_hash[1]]] - models_content_res, models_hash_res = get_models(subtuple) - - self.assertEqual(models_content_res, models_content) - self.assertIsNotNone(models_hash_res, models_hash) - - def test_get_algo(self): - algo_content = self.algo.read() - algo_hash = get_hash(self.algo) - - with mock.patch('substrapp.tasks.get_remote_file') as mget_remote_file: - mget_remote_file.return_value = algo_content, algo_hash - self.assertEqual((algo_content, algo_hash), get_algo({'algo': ''})) - - def test_get_objective(self): - metrics_content = self.script.read() - objective_hash = get_hash(self.script) - - with mock.patch('substrapp.models.Objective.objects.get') as mget, \ - mock.patch('substrapp.tasks.get_remote_file') as mget_remote_file, \ - mock.patch('substrapp.models.Objective.objects.update_or_create') as mupdate_or_create: - - mget.return_value = FakeObjective() - mget_remote_file.return_value = metrics_content, objective_hash - mupdate_or_create.return_value = FakeObjective(), True - - objective = get_objective({'objective': {'hash': objective_hash, - 'metrics': ''}}) - self.assertTrue(isinstance(objective, FakeObjective)) - - def test_compute_docker(self): - cpu_set, gpu_set = None, None - client = docker.from_env() - - dockerfile_path = os.path.join(self.subtuple_path, 'Dockerfile') - with open(dockerfile_path, 'w') as f: - f.write('FROM library/hello-world') - - hash_docker = uuid.uuid4().hex - result = compute_docker(client, self.ResourcesManager, - self.subtuple_path, 'test_compute_docker_' + hash_docker, - 'test_compute_docker_name_' + hash_docker, None, None) - - self.assertIsNone(cpu_set) - self.assertIsNone(gpu_set) - - self.assertIn('CPU', result) - self.assertIn('GPU', result) - self.assertIn('Mem', result) - self.assertIn('GPU Mem', result) - - def test_build_subtuple_folders(self): - with mock.patch('substrapp.tasks.getattr') as getattr: - getattr.return_value = self.subtuple_path - - subtuple_key = 'test1234' - subtuple = {'key': subtuple_key} - subtuple_directory = build_subtuple_folders(subtuple) - - self.assertTrue(os.path.exists(subtuple_directory)) - self.assertEqual(os.path.join(self.subtuple_path, f'subtuple/{subtuple["key"]}'), subtuple_directory) - - for root, dirs, files in os.walk(subtuple_directory): - nb_subfolders = len(dirs) - - self.assertTrue(5, nb_subfolders) - - def test_prepareTasks(self): - - class FakeSettings(object): - def __init__(self): - self.LEDGER = {'signcert': 'signcert', - 'org': 'owkin', - 'peer': 'peer'} - - self.MEDIA_ROOT = MEDIA_ROOT - - subtuple = [{'key': 'subtuple_test'}] - - with mock.patch('substrapp.tasks.settings') as msettings, \ - mock.patch('substrapp.tasks.get_hash') as mget_hash, \ - mock.patch('substrapp.tasks.queryLedger') as mqueryLedger, \ - mock.patch('substrapp.tasks.get_objective') as mget_objective, \ - mock.patch('substrapp.tasks.get_algo') as mget_algo, \ - mock.patch('substrapp.tasks.get_model') as mget_model, \ - mock.patch('substrapp.tasks.build_subtuple_folders') as mbuild_subtuple_folders, \ - mock.patch('substrapp.tasks.put_opener') as mput_opener, \ - mock.patch('substrapp.tasks.put_data_sample') as mput_data_sample, \ - mock.patch('substrapp.tasks.put_metric') as mput_metric, \ - mock.patch('substrapp.tasks.put_algo') as mput_algo, \ - mock.patch('substrapp.tasks.put_model') as 
mput_model: - - msettings.return_value = FakeSettings() - mget_hash.return_value = 'owkinhash' - mqueryLedger.return_value = subtuple, 200 - mget_objective.return_value = 'objective' - mget_algo.return_value = 'algo', 'algo_hash' - mget_model.return_value = 'model', 'model_hash' - mbuild_subtuple_folders.return_value = MEDIA_ROOT - mput_opener.return_value = 'opener' - mput_data_sample.return_value = 'data' - mput_metric.return_value = 'metric' - mput_algo.return_value = 'algo' - mput_model.return_value = 'model' - - with mock.patch('substrapp.tasks.queryLedger') as mqueryLedger: - mqueryLedger.return_value = 'data', 404 - prepareTask('traintuple', 'inModels') - - with mock.patch('substrapp.tasks.invokeLedger') as minvokeLedger, \ - mock.patch('substrapp.tasks.computeTask.apply_async') as mapply_async: - minvokeLedger.return_value = 'data', 201 - mapply_async.return_value = 'doTask' - prepareTask('traintuple', 'inModels') - - def test_doTask(self): - - class FakeSettings(object): - def __init__(self): - self.LEDGER = {'signcert': 'signcert', - 'org': 'owkin', - 'peer': 'peer'} - - self.MEDIA_ROOT = MEDIA_ROOT - - subtuple_key = 'test_owkin' - subtuple = {'key': subtuple_key, 'inModels': None} - subtuple_directory = build_subtuple_folders(subtuple) - - with mock.patch('substrapp.tasks.settings') as msettings, \ - mock.patch('substrapp.tasks.getattr') as mgetattr, \ - mock.patch('substrapp.tasks.invokeLedger') as minvokeLedger: - msettings.return_value = FakeSettings() - mgetattr.return_value = self.subtuple_path - minvokeLedger.return_value = 'data', 200 - - for name in ['opener', 'metrics']: - with open(os.path.join(subtuple_directory, f'{name}/{name}.py'), 'w') as f: - f.write('Hello World') - - perf = 0.3141592 - with open(os.path.join(subtuple_directory, 'pred/perf.json'), 'w') as f: - f.write(f'{{"all": {perf}}}') - - with open(os.path.join(subtuple_directory, 'model/model'), 'w') as f: - f.write("MODEL") - - with mock.patch('substrapp.tasks.compute_docker') as mcompute_docker: - mcompute_docker.return_value = 'DONE' - doTask(subtuple, 'traintuple') - - def test_computeTask(self): - - class FakeSettings(object): - def __init__(self): - self.LEDGER = {'signcert': 'signcert', - 'org': 'owkin', - 'peer': 'peer'} - - self.MEDIA_ROOT = MEDIA_ROOT - - subtuple_key = 'test_owkin' - subtuple = {'key': subtuple_key, 'inModels': None} - subtuple_directory = build_subtuple_folders(subtuple) - - with mock.patch('substrapp.tasks.settings') as msettings, \ - mock.patch('substrapp.tasks.getattr') as mgetattr, \ - mock.patch('substrapp.tasks.invokeLedger') as minvokeLedger: - msettings.return_value = FakeSettings() - mgetattr.return_value = self.subtuple_path - minvokeLedger.return_value = 'data', 200 - - for name in ['opener', 'metrics']: - with open(os.path.join(subtuple_directory, f'{name}/{name}.py'), 'w') as f: - f.write('Hello World') - - perf = 0.3141592 - with open(os.path.join(subtuple_directory, 'pred/perf.json'), 'w') as f: - f.write(f'{{"all": {perf}}}') - - with open(os.path.join(subtuple_directory, 'model/model'), 'w') as f: - f.write("MODEL") - - with mock.patch('substrapp.tasks.compute_docker') as mcompute_docker, \ - mock.patch('substrapp.tasks.prepareMaterials') as mprepareMaterials, \ - mock.patch('substrapp.tasks.invokeLedger') as minvokeLedger: - - mcompute_docker.return_value = 'DONE' - mprepareMaterials.return_value = 'DONE' - minvokeLedger.return_value = 'data', 201 - - computeTask('traintuple', subtuple, 'inModels', None) diff --git a/substrabac/substrapp/tests/tests_views.py 
b/substrabac/substrapp/tests/tests_views.py deleted file mode 100644 index 76d418adb..000000000 --- a/substrabac/substrapp/tests/tests_views.py +++ /dev/null @@ -1,1170 +0,0 @@ -import os -import shutil -import logging - -import mock - -from django.urls import reverse -from django.test import override_settings - -from rest_framework import status -from rest_framework.test import APITestCase - -from substrapp.views import DataManagerViewSet, TrainTupleViewSet, TestTupleViewSet, DataSampleViewSet - -from substrapp.serializers import LedgerDataSampleSerializer, LedgerObjectiveSerializer, LedgerAlgoSerializer - -from substrapp.views.utils import JsonException, ComputeHashMixin, getObjectFromLedger -from substrapp.views.datasample import path_leaf, compute_dryrun as data_sample_compute_dryrun -from substrapp.views.objective import compute_dryrun as objective_compute_dryrun -from substrapp.utils import compute_hash, get_hash - -from substrapp.models import DataManager - -from .common import get_sample_objective, get_sample_datamanager, get_sample_algo, get_sample_model -from .common import FakeAsyncResult, FakeRequest, FakeFilterDataManager, FakeTask, FakeDataManager -from .assets import objective, datamanager, algo, traintuple, model, testtuple - -MEDIA_ROOT = "/tmp/unittests_views/" - - -# APITestCase -class ViewTests(APITestCase): - - def setUp(self): - pass - - def tearDown(self): - pass - - def test_data_sample_path_view(self): - self.assertEqual('tutu', path_leaf('/toto/tata/tutu')) - self.assertEqual('toto', path_leaf('/toto/')) - - def test_utils_ComputeHashMixin(self): - - compute = ComputeHashMixin() - myfile = 'toto' - key = 'tata' - - myfilehash = compute_hash(myfile) - myfilehashwithkey = compute_hash(myfile, key) - - self.assertEqual(myfilehash, compute.compute_hash(myfile)) - self.assertEqual(myfilehashwithkey, compute.compute_hash(myfile, key)) - - def test_utils_getObjectFromLedger(self): - - with mock.patch('substrapp.views.utils.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(objective, status.HTTP_200_OK)] - data = getObjectFromLedger('', 'queryObjective') - - self.assertEqual(data, objective) - - with mock.patch('substrapp.views.utils.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [('', status.HTTP_400_BAD_REQUEST)] - with self.assertRaises(JsonException): - getObjectFromLedger('', 'queryAllObjective') - - -# APITestCase -@override_settings(MEDIA_ROOT=MEDIA_ROOT) -@override_settings(DRYRUN_ROOT=MEDIA_ROOT) -@override_settings(SITE_HOST='localhost') -@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) -@override_settings(DEFAULT_DOMAIN='https://localhost') -class ObjectiveViewTests(APITestCase): - - def setUp(self): - if not os.path.exists(MEDIA_ROOT): - os.makedirs(MEDIA_ROOT) - - self.objective_description, self.objective_description_filename, \ - self.objective_metrics, self.objective_metrics_filename = get_sample_objective() - - self.extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0' - } - - self.logger = logging.getLogger('django.request') - self.previous_level = self.logger.getEffectiveLevel() - self.logger.setLevel(logging.ERROR) - - def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass - - self.logger.setLevel(self.previous_level) - - def test_objective_list_empty(self): - url = reverse('substrapp:objective-list') - with mock.patch('substrapp.views.objective.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(None, status.HTTP_200_OK), - (['ISIC'], status.HTTP_200_OK)] - - 
response = self.client.get(url, **self.extra) - r = response.json() - self.assertEqual(r, [[]]) - - response = self.client.get(url, **self.extra) - r = response.json() - self.assertEqual(r, [['ISIC']]) - - def test_objective_list_filter_fail(self): - url = reverse('substrapp:objective-list') - with mock.patch('substrapp.views.objective.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(objective, status.HTTP_200_OK)] - - search_params = '?search=challenERRORge' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertIn('Malformed search filters', r['message']) - - def test_objective_list_filter_name(self): - url = reverse('substrapp:objective-list') - with mock.patch('substrapp.views.objective.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(objective, status.HTTP_200_OK)] - - search_params = '?search=objective%253Aname%253ASkin%2520Lesion%2520Classification%2520Objective' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(len(r[0]), 2) - - def test_objective_list_filter_metrics(self): - url = reverse('substrapp:objective-list') - with mock.patch('substrapp.views.objective.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(objective, status.HTTP_200_OK)] - - search_params = '?search=objective%253Ametrics%253Amacro-average%2520recall' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(len(r[0]), len(objective)) - - def test_objective_list_filter_datamanager(self): - url = reverse('substrapp:objective-list') - with mock.patch('substrapp.views.objective.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(objective, status.HTTP_200_OK), - (datamanager, status.HTTP_200_OK)] - - search_params = '?search=dataset%253Aname%253ASimplified%2520ISIC%25202018' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(len(r[0]), 1) - - def test_objective_list_filter_model(self): - url = reverse('substrapp:objective-list') - with mock.patch('substrapp.views.objective.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(objective, status.HTTP_200_OK), - (traintuple, status.HTTP_200_OK)] - - pkhash = model[0]['traintuple']['outModel']['hash'] - search_params = f'?search=model%253Ahash%253A{pkhash}' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(len(r[0]), 1) - - def test_objective_retrieve(self): - url = reverse('substrapp:objective-list') - - with mock.patch('substrapp.views.objective.getObjectFromLedger') as mgetObjectFromLedger, \ - mock.patch('substrapp.views.objective.requests.get') as mrequestsget: - mgetObjectFromLedger.return_value = objective[0] - - with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), - '../../fixtures/owkin/objectives/objective0/description.md'), 'rb') as f: - content = f.read() - - mrequestsget.return_value = FakeRequest(status=status.HTTP_200_OK, - content=content) - - search_params = f'{compute_hash(content)}/' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(r, objective[0]) - - def test_objective_retrieve_fail(self): - - dir_path = os.path.dirname(os.path.realpath(__file__)) - url = reverse('substrapp:objective-list') - - # PK hash < 64 chars - search_params = '42303efa663015e729159833a12ffb510ff/' - response = self.client.get(url + search_params, **self.extra) - 
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - # PK hash not hexa - search_params = 'X' * 64 + '/' - response = self.client.get(url + search_params, **self.extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - with mock.patch('substrapp.views.objective.getObjectFromLedger') as mgetObjectFromLedger: - mgetObjectFromLedger.side_effect = JsonException('TEST') - - search_params = f'{get_hash(os.path.join(dir_path, "../../fixtures/owkin/objectives/objective0/description.md"))}/' - response = self.client.get(url + search_params, **self.extra) - - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_objective_create(self): - url = reverse('substrapp:objective-list') - - dir_path = os.path.dirname(os.path.realpath(__file__)) - - description_path = os.path.join(dir_path, '../../fixtures/owkin/objectives/objective0/description.md') - metrics_path = os.path.join(dir_path, '../../fixtures/owkin/objectives/objective0/metrics.py') - - pkhash = get_hash(description_path) - - test_data_manager_key = get_hash(os.path.join(dir_path, '../../fixtures/owkin/datamanagers/datamanager0/opener.py')) - - data = { - 'name': 'Simplified skin lesion classification', - 'description': open(description_path, 'rb'), - 'metrics_name': 'macro-average recall', - 'metrics': open(metrics_path, 'rb'), - 'permissions': 'all', - 'test_data_sample_keys': [ - "2d0f943aa81a9cb3fe84b162559ce6aff068ccb04e0cb284733b8f9d7e06517e", - "533ee6e7b9d8b247e7e853b24547f57e6ef351852bac0418f13a0666173448f1" - ], - 'test_data_manager_key': test_data_manager_key - } - - with mock.patch.object(LedgerObjectiveSerializer, 'create') as mcreate: - - mcreate.return_value = ({}, - status.HTTP_201_CREATED) - - response = self.client.post(url, data=data, format='multipart', **self.extra) - - self.assertEqual(response.data['pkhash'], pkhash) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - - data['description'].close() - data['metrics'].close() - - def test_objective_create_dryrun(self): - - url = reverse('substrapp:objective-list') - - dir_path = os.path.dirname(os.path.realpath(__file__)) - - description_path = os.path.join(dir_path, '../../fixtures/owkin/objectives/objective0/description.md') - metrics_path = os.path.join(dir_path, '../../fixtures/owkin/objectives/objective0/metrics.py') - - test_data_manager_key = get_hash(os.path.join(dir_path, '../../fixtures/owkin/datamanagers/datamanager0/opener.py')) - - data = { - 'name': 'Simplified skin lesion classification', - 'description': open(description_path, 'rb'), - 'metrics_name': 'macro-average recall', - 'metrics': open(metrics_path, 'rb'), - 'permissions': 'all', - 'test_data_sample_keys': [ - "2d0f943aa81a9cb3fe84b162559ce6aff068ccb04e0cb284733b8f9d7e06517e", - "533ee6e7b9d8b247e7e853b24547f57e6ef351852bac0418f13a0666173448f1" - ], - 'test_data_manager_key': test_data_manager_key, - 'dryrun': True - } - - with mock.patch('substrapp.views.objective.compute_dryrun.apply_async') as mdryrun_task: - - mdryrun_task.return_value = FakeTask('42') - response = self.client.post(url, data=data, format='multipart', **self.extra) - - self.assertEqual(response.data['id'], '42') - self.assertEqual(response.data['message'], 'Your dry-run has been taken in account. 
You can follow the task execution on https://localhost/task/42/') - self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) - - data['description'].close() - data['metrics'].close() - - def test_objective_compute_dryrun(self): - - dir_path = os.path.dirname(os.path.realpath(__file__)) - - metrics_path = os.path.join(dir_path, '../../fixtures/owkin/objectives/objective0/metrics.py') - description_path = os.path.join(dir_path, '../../fixtures/owkin/objectives/objective0/description.md') - shutil.copy(metrics_path, os.path.join(MEDIA_ROOT, 'metrics.py')) - - opener_path = os.path.join(dir_path, '../../fixtures/owkin/datamanagers/datamanager0/opener.py') - - with open(opener_path, 'rb') as f: - opener_content = f.read() - - pkhash = get_hash(description_path) - - test_data_manager_key = compute_hash(opener_content) - - with mock.patch('substrapp.views.objective.getObjectFromLedger') as mdatamanager,\ - mock.patch('substrapp.views.objective.get_computed_hash') as mopener: - mdatamanager.return_value = {'opener': {'storageAddress': 'test'}} - mopener.return_value = (opener_content, pkhash) - objective_compute_dryrun(os.path.join(MEDIA_ROOT, 'metrics.py'), test_data_manager_key, pkhash) - - -# APITestCase -@override_settings(MEDIA_ROOT=MEDIA_ROOT) -@override_settings(DRYRUN_ROOT=MEDIA_ROOT) -@override_settings(SITE_HOST='localhost') -@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) -class AlgoViewTests(APITestCase): - - def setUp(self): - if not os.path.exists(MEDIA_ROOT): - os.makedirs(MEDIA_ROOT) - - self.algo, self.algo_filename = get_sample_algo() - - self.extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0' - } - self.logger = logging.getLogger('django.request') - self.previous_level = self.logger.getEffectiveLevel() - self.logger.setLevel(logging.ERROR) - - def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass - - self.logger.setLevel(self.previous_level) - - def test_algo_list_empty(self): - url = reverse('substrapp:algo-list') - with mock.patch('substrapp.views.algo.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(None, status.HTTP_200_OK), - (['ISIC'], status.HTTP_200_OK)] - - response = self.client.get(url, **self.extra) - r = response.json() - self.assertEqual(r, [[]]) - - response = self.client.get(url, **self.extra) - r = response.json() - self.assertEqual(r, [['ISIC']]) - - def test_algo_list_filter_fail(self): - url = reverse('substrapp:algo-list') - with mock.patch('substrapp.views.algo.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(algo, status.HTTP_200_OK)] - - search_params = '?search=algERRORo' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertIn('Malformed search filters', r['message']) - - def test_algo_list_filter_name(self): - url = reverse('substrapp:algo-list') - with mock.patch('substrapp.views.algo.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(algo, status.HTTP_200_OK)] - - search_params = '?search=algo%253Aname%253ALogistic%2520regression' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(len(r[0]), 1) - - def test_algo_list_filter_datamanager(self): - url = reverse('substrapp:algo-list') - with mock.patch('substrapp.views.algo.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(algo, status.HTTP_200_OK), - (datamanager, status.HTTP_200_OK)] - - search_params = '?search=dataset%253Aname%253ASimplified%2520ISIC%25202018' - response = 
self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(len(r[0]), len(algo)) - - def test_algo_list_filter_objective(self): - url = reverse('substrapp:algo-list') - with mock.patch('substrapp.views.algo.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(algo, status.HTTP_200_OK), - (objective, status.HTTP_200_OK)] - - search_params = '?search=objective%253Aname%253ASkin%2520Lesion%2520Classification%2520Objective' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(len(r[0]), 3) - - def test_algo_list_filter_model(self): - url = reverse('substrapp:algo-list') - with mock.patch('substrapp.views.algo.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(algo, status.HTTP_200_OK), - (traintuple, status.HTTP_200_OK)] - - pkhash = model[0]['traintuple']['outModel']['hash'] - search_params = f'?search=model%253Ahash%253A{pkhash}' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(len(r[0]), 1) - - def test_algo_retrieve(self): - dir_path = os.path.dirname(os.path.realpath(__file__)) - algo_hash = get_hash(os.path.join(dir_path, '../../fixtures/chunantes/algos/algo4/algo.tar.gz')) - url = reverse('substrapp:algo-list') - algo_response = [a for a in algo if a['key'] == algo_hash][0] - with mock.patch('substrapp.views.algo.getObjectFromLedger') as mgetObjectFromLedger, \ - mock.patch('substrapp.views.algo.requests.get') as mrequestsget: - - with open(os.path.join(dir_path, - '../../fixtures/chunantes/algos/algo4/description.md'), 'rb') as f: - content = f.read() - mgetObjectFromLedger.return_value = algo_response - - mrequestsget.return_value = FakeRequest(status=status.HTTP_200_OK, - content=content) - - search_params = f'{algo_hash}/' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(r, algo_response) - - def test_algo_retrieve_fail(self): - - dir_path = os.path.dirname(os.path.realpath(__file__)) - url = reverse('substrapp:algo-list') - - # PK hash < 64 chars - search_params = '42303efa663015e729159833a12ffb510ff/' - response = self.client.get(url + search_params, **self.extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - # PK hash not hexa - search_params = 'X' * 64 + '/' - response = self.client.get(url + search_params, **self.extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - with mock.patch('substrapp.views.algo.getObjectFromLedger') as mgetObjectFromLedger: - mgetObjectFromLedger.side_effect = JsonException('TEST') - - search_params = f'{get_hash(os.path.join(dir_path, "../../fixtures/owkin/objectives/objective0/description.md"))}/' - response = self.client.get(url + search_params, **self.extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_algo_create(self): - url = reverse('substrapp:algo-list') - - dir_path = os.path.dirname(os.path.realpath(__file__)) - - algo_path = os.path.join(dir_path, '../../fixtures/chunantes/algos/algo3/algo.tar.gz') - description_path = os.path.join(dir_path, '../../fixtures/chunantes/algos/algo3/description.md') - - pkhash = get_hash(algo_path) - - data = {'name': 'Logistic regression', - 'file': open(algo_path, 'rb'), - 'description': open(description_path, 'rb'), - 'objective_key': get_hash(os.path.join(dir_path, '../../fixtures/chunantes/objectives/objective0/description.md')), - 'permissions': 'all'} - - with 
mock.patch.object(LedgerAlgoSerializer, 'create') as mcreate: - - mcreate.return_value = ({}, - status.HTTP_201_CREATED) - - response = self.client.post(url, data=data, format='multipart', **self.extra) - - self.assertEqual(response.data['pkhash'], pkhash) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - - data['description'].close() - data['file'].close() - - -# APITestCase -@override_settings(MEDIA_ROOT=MEDIA_ROOT) -@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) -class ModelViewTests(APITestCase): - - def setUp(self): - if not os.path.exists(MEDIA_ROOT): - os.makedirs(MEDIA_ROOT) - - self.model, self.model_filename = get_sample_model() - - self.extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0' - } - - self.logger = logging.getLogger('django.request') - self.previous_level = self.logger.getEffectiveLevel() - self.logger.setLevel(logging.ERROR) - - def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass - - self.logger.setLevel(self.previous_level) - - def test_model_list_empty(self): - url = reverse('substrapp:model-list') - with mock.patch('substrapp.views.model.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(None, status.HTTP_200_OK), - (['ISIC'], status.HTTP_200_OK)] - - response = self.client.get(url, **self.extra) - r = response.json() - self.assertEqual(r, [[]]) - - response = self.client.get(url, **self.extra) - r = response.json() - self.assertEqual(r, [['ISIC']]) - - def test_model_list_filter_fail(self): - - with mock.patch('substrapp.views.model.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(model, status.HTTP_200_OK)] - - url = reverse('substrapp:model-list') - search_params = '?search=modeERRORl' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - self.assertIn('Malformed search filters', r['message']) - - def test_model_list_filter_hash(self): - - with mock.patch('substrapp.views.model.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(model, status.HTTP_200_OK)] - - pkhash = model[0]['traintuple']['outModel']['hash'] - url = reverse('substrapp:model-list') - search_params = f'?search=model%253Ahash%253A{pkhash}' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - self.assertEqual(len(r[0]), 1) - - def test_model_list_filter_datamanager(self): - url = reverse('substrapp:model-list') - with mock.patch('substrapp.views.model.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(model, status.HTTP_200_OK), - (datamanager, status.HTTP_200_OK)] - - search_params = '?search=dataset%253Aname%253AISIC%25202018' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(len(r[0]), 1) - - def test_model_list_filter_objective(self): - url = reverse('substrapp:model-list') - with mock.patch('substrapp.views.model.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(model, status.HTTP_200_OK), - (objective, status.HTTP_200_OK)] - - search_params = '?search=objective%253Aname%253ASkin%2520Lesion%2520Classification%2520Objective' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(len(r[0]), 1) - - def test_model_list_filter_algo(self): - url = reverse('substrapp:model-list') - with mock.patch('substrapp.views.model.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(model, status.HTTP_200_OK), - (algo, status.HTTP_200_OK)] - - search_params = 
'?search=algo%253Aname%253ALogistic%2520regression' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(len(r[0]), 1) - - def test_model_retrieve(self): - - with mock.patch('substrapp.views.model.getObjectFromLedger') as mgetObjectFromLedger, \ - mock.patch('substrapp.views.model.requests.get') as mrequestsget, \ - mock.patch('substrapp.views.model.ModelViewSet.compute_hash') as mcomputed_hash: - mgetObjectFromLedger.return_value = model[0] - - mrequestsget.return_value = FakeRequest(status=status.HTTP_200_OK, - content=self.model.read().encode()) - - mcomputed_hash.return_value = model[0]['traintuple']['outModel']['hash'] - - url = reverse('substrapp:model-list') - search_params = model[0]['traintuple']['outModel']['hash'] + '/' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - self.assertEqual(r, model[0]) - - def test_model_retrieve_fail(self): - - dir_path = os.path.dirname(os.path.realpath(__file__)) - - url = reverse('substrapp:model-list') - - # PK hash < 64 chars - search_params = '42303efa663015e729159833a12ffb510ff/' - response = self.client.get(url + search_params, **self.extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - # PK hash not hexa - search_params = 'X' * 64 + '/' - response = self.client.get(url + search_params, **self.extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - with mock.patch('substrapp.views.model.getObjectFromLedger') as mgetObjectFromLedger: - mgetObjectFromLedger.side_effect = JsonException('TEST') - - search_params = f'{get_hash(os.path.join(dir_path, "../../fixtures/owkin/objectives/objective0/description.md"))}/' - response = self.client.get(url + search_params, **self.extra) - - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - -# APITestCase -@override_settings(MEDIA_ROOT=MEDIA_ROOT) -@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) -class DataManagerViewTests(APITestCase): - - def setUp(self): - if not os.path.exists(MEDIA_ROOT): - os.makedirs(MEDIA_ROOT) - - self.data_description, self.data_description_filename, \ - self.data_data_opener, self.data_opener_filename = get_sample_datamanager() - - self.extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0' - } - - self.logger = logging.getLogger('django.request') - self.previous_level = self.logger.getEffectiveLevel() - self.logger.setLevel(logging.ERROR) - - def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass - - self.logger.setLevel(self.previous_level) - - def test_datamanager_list_empty(self): - url = reverse('substrapp:data_manager-list') - with mock.patch('substrapp.views.datamanager.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(None, status.HTTP_200_OK), - (['ISIC'], status.HTTP_200_OK)] - - response = self.client.get(url, **self.extra) - r = response.json() - self.assertEqual(r, [[]]) - - response = self.client.get(url, **self.extra) - r = response.json() - self.assertEqual(r, [['ISIC']]) - - def test_datamanager_list_filter_fail(self): - url = reverse('substrapp:data_manager-list') - with mock.patch('substrapp.views.datamanager.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(datamanager, status.HTTP_200_OK)] - - search_params = '?search=dataseERRORt' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertIn('Malformed search filters', r['message']) - - def 
test_datamanager_list_filter_name(self): - url = reverse('substrapp:data_manager-list') - with mock.patch('substrapp.views.datamanager.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(datamanager, status.HTTP_200_OK)] - - search_params = '?search=dataset%253Aname%253ASimplified%2520ISIC%25202018' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(len(r[0]), 1) - - def test_datamanager_list_filter_objective(self): - url = reverse('substrapp:data_manager-list') - with mock.patch('substrapp.views.datamanager.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(datamanager, status.HTTP_200_OK), - (objective, status.HTTP_200_OK)] - - search_params = '?search=objective%253Aname%253ASkin%2520Lesion%2520Classification%2520Objective' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(len(r[0]), 2) - - def test_datamanager_list_filter_model(self): - url = reverse('substrapp:data_manager-list') - with mock.patch('substrapp.views.datamanager.queryLedger') as mqueryLedger: - mqueryLedger.side_effect = [(datamanager, status.HTTP_200_OK), - (traintuple, status.HTTP_200_OK)] - pkhash = model[0]['traintuple']['outModel']['hash'] - search_params = f'?search=model%253Ahash%253A{pkhash}' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(len(r[0]), 2) - - def test_datamanager_retrieve(self): - url = reverse('substrapp:data_manager-list') - datamanager_response = [d for d in datamanager if d['key'] == '615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7'][0] - with mock.patch.object(DataManagerViewSet, 'getObjectFromLedger') as mgetObjectFromLedger, \ - mock.patch('substrapp.views.datamanager.requests.get') as mrequestsget: - mgetObjectFromLedger.return_value = datamanager_response - - with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), - '../../fixtures/chunantes/datamanagers/datamanager0/opener.py'), 'rb') as f: - opener_content = f.read() - - with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), - '../../fixtures/chunantes/datamanagers/datamanager0/description.md'), 'rb') as f: - description_content = f.read() - - mrequestsget.side_effect = [FakeRequest(status=status.HTTP_200_OK, - content=opener_content), - FakeRequest(status=status.HTTP_200_OK, - content=description_content)] - - search_params = '615ce631b93c185b492dfc97ed5dea27430d871fa4e50678bab3c79ce2ec6cb7/' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - - self.assertEqual(r, datamanager_response) - - def test_datamanager_retrieve_fail(self): - - dir_path = os.path.dirname(os.path.realpath(__file__)) - url = reverse('substrapp:data_manager-list') - - # PK hash < 64 chars - search_params = '42303efa663015e729159833a12ffb510ff/' - response = self.client.get(url + search_params, **self.extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - # PK hash not hexa - search_params = 'X' * 64 + '/' - response = self.client.get(url + search_params, **self.extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - with mock.patch.object(DataManagerViewSet, 'getObjectFromLedger') as mgetObjectFromLedger: - mgetObjectFromLedger.side_effect = JsonException('TEST') - - search_params = f'{get_hash(os.path.join(dir_path, "../../fixtures/owkin/objectives/objective0/description.md"))}/' - response = self.client.get(url + search_params, **self.extra) - 
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_datamanager_create_dryrun(self): - url = reverse('substrapp:data_manager-list') - - dir_path = os.path.dirname(os.path.realpath(__file__)) - files = {'data_opener': open(os.path.join(dir_path, - '../../fixtures/chunantes/datamanagers/datamanager0/opener.py'), - 'rb'), - 'description': open(os.path.join(dir_path, - '../../fixtures/chunantes/datamanagers/datamanager0/description.md'), - 'rb')} - - data = { - 'name': 'ISIC 2018', - 'type': 'Images', - 'permissions': 'all', - 'dryrun': True - } - - response = self.client.post(url, {**data, **files}, format='multipart', **self.extra) - self.assertEqual(response.data, {'message': f'Your data opener is valid. You can remove the dryrun option.'}) - self.assertEqual(response.status_code, status.HTTP_200_OK) - - for x in files: - files[x].close() - - -# APITestCase -@override_settings(MEDIA_ROOT=MEDIA_ROOT) -@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) -class TraintupleViewTests(APITestCase): - - def setUp(self): - if not os.path.exists(MEDIA_ROOT): - os.makedirs(MEDIA_ROOT) - - self.extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0' - } - - self.logger = logging.getLogger('django.request') - self.previous_level = self.logger.getEffectiveLevel() - self.logger.setLevel(logging.ERROR) - - def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass - - self.logger.setLevel(self.previous_level) - - def test_traintuple_queryset(self): - traintuple_view = TrainTupleViewSet() - self.assertFalse(traintuple_view.get_queryset()) - - def test_traintuple_list_empty(self): - url = reverse('substrapp:traintuple-list') - with mock.patch('substrapp.views.traintuple.queryLedger') as mqueryLedger: - mqueryLedger.return_value = ([[]], status.HTTP_200_OK) - - response = self.client.get(url, **self.extra) - r = response.json() - self.assertEqual(r, [[]]) - - def test_traintuple_retrieve(self): - - with mock.patch.object(TrainTupleViewSet, 'getObjectFromLedger') as mgetObjectFromLedger: - mgetObjectFromLedger.return_value = traintuple[0] - url = reverse('substrapp:traintuple-list') - search_params = 'c164f4c714a78c7e2ba2016de231cdd41e3eac61289e08c1f711e74915a0868f/' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - self.assertEqual(r, traintuple[0]) - - def test_traintuple_retrieve_fail(self): - - dir_path = os.path.dirname(os.path.realpath(__file__)) - url = reverse('substrapp:traintuple-list') - - # PK hash < 64 chars - search_params = '42303efa663015e729159833a12ffb510ff/' - response = self.client.get(url + search_params, **self.extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - # PK hash not hexa - search_params = 'X' * 64 + '/' - response = self.client.get(url + search_params, **self.extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - with mock.patch.object(TrainTupleViewSet, 'getObjectFromLedger') as mgetObjectFromLedger: - mgetObjectFromLedger.side_effect = JsonException('TEST') - - search_params = f'{get_hash(os.path.join(dir_path, "../../fixtures/owkin/objectives/objective0/description.md"))}/' - response = self.client.get(url + search_params, **self.extra) - - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - -# APITestCase -@override_settings(MEDIA_ROOT=MEDIA_ROOT) -@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) -class TesttupleViewTests(APITestCase): - - def setUp(self): - if not 
os.path.exists(MEDIA_ROOT): - os.makedirs(MEDIA_ROOT) - - self.extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0' - } - - self.logger = logging.getLogger('django.request') - self.previous_level = self.logger.getEffectiveLevel() - self.logger.setLevel(logging.ERROR) - - def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass - - self.logger.setLevel(self.previous_level) - - def test_testtuple_queryset(self): - testtuple_view = TestTupleViewSet() - self.assertFalse(testtuple_view.get_queryset()) - - def test_testtuple_list_empty(self): - url = reverse('substrapp:testtuple-list') - with mock.patch('substrapp.views.testtuple.queryLedger') as mqueryLedger: - mqueryLedger.return_value = ([[]], status.HTTP_200_OK) - - response = self.client.get(url, **self.extra) - r = response.json() - self.assertEqual(r, [[]]) - - def test_testtuple_retrieve(self): - - with mock.patch('substrapp.views.testtuple.getObjectFromLedger') as mgetObjectFromLedger: - mgetObjectFromLedger.return_value = testtuple[0] - url = reverse('substrapp:testtuple-list') - search_params = 'c164f4c714a78c7e2ba2016de231cdd41e3eac61289e08c1f711e74915a0868f/' - response = self.client.get(url + search_params, **self.extra) - r = response.json() - self.assertEqual(r, testtuple[0]) - - def test_testtuple_retrieve_fail(self): - - dir_path = os.path.dirname(os.path.realpath(__file__)) - url = reverse('substrapp:testtuple-list') - - # PK hash < 64 chars - search_params = '42303efa663015e729159833a12ffb510ff/' - response = self.client.get(url + search_params, **self.extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - # PK hash not hexa - search_params = 'X' * 64 + '/' - response = self.client.get(url + search_params, **self.extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - with mock.patch('substrapp.views.testtuple.getObjectFromLedger') as mgetObjectFromLedger: - mgetObjectFromLedger.side_effect = JsonException('TEST') - - search_params = f'{get_hash(os.path.join(dir_path, "../../fixtures/owkin/objectives/objective0/description.md"))}/' - response = self.client.get(url + search_params, **self.extra) - - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - -# APITestCase -@override_settings(MEDIA_ROOT=MEDIA_ROOT) -@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) -class TaskViewTests(APITestCase): - - def setUp(self): - if not os.path.exists(MEDIA_ROOT): - os.makedirs(MEDIA_ROOT) - - self.extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0' - } - - self.logger = logging.getLogger('django.request') - self.previous_level = self.logger.getEffectiveLevel() - self.logger.setLevel(logging.ERROR) - - def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass - - self.logger.setLevel(self.previous_level) - - def test_task_retrieve(self): - - url = reverse('substrapp:task-detail', kwargs={'pk': 'pk'}) - with mock.patch('substrapp.views.task.AsyncResult') as mAsyncResult: - mAsyncResult.return_value = FakeAsyncResult(status='SUCCESS') - response = self.client.get(url, **self.extra) - self.assertEqual(response.status_code, status.HTTP_200_OK) - - def test_task_retrieve_fail(self): - url = reverse('substrapp:task-detail', kwargs={'pk': 'pk'}) - with mock.patch('substrapp.views.task.AsyncResult') as mAsyncResult: - mAsyncResult.return_value = FakeAsyncResult() - response = self.client.get(url, **self.extra) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def 
test_task_retrieve_pending(self): - url = reverse('substrapp:task-detail', kwargs={'pk': 'pk'}) - with mock.patch('substrapp.views.task.AsyncResult') as mAsyncResult: - mAsyncResult.return_value = FakeAsyncResult(status='PENDING', successful=False) - response = self.client.get(url, **self.extra) - self.assertEqual(response.data['message'], - 'Task is either waiting, does not exist in this context or has been removed after 24h') - - self.assertEqual(response.status_code, status.HTTP_200_OK) - - -# APITestCase -@override_settings(MEDIA_ROOT=MEDIA_ROOT) -@override_settings(DRYRUN_ROOT=MEDIA_ROOT) -@override_settings(SITE_HOST='localhost') -@override_settings(LEDGER={'name': 'test-org', 'peer': 'test-peer'}) -@override_settings(DEFAULT_DOMAIN='https://localhost') -class DataViewTests(APITestCase): - - def setUp(self): - if not os.path.exists(MEDIA_ROOT): - os.makedirs(MEDIA_ROOT) - - self.data_description, self.data_description_filename, \ - self.data_data_opener, self.data_opener_filename = get_sample_datamanager() - - self.extra = { - 'HTTP_ACCEPT': 'application/json;version=0.0' - } - - self.logger = logging.getLogger('django.request') - self.previous_level = self.logger.getEffectiveLevel() - self.logger.setLevel(logging.ERROR) - - def tearDown(self): - try: - shutil.rmtree(MEDIA_ROOT) - except FileNotFoundError: - pass - - self.logger.setLevel(self.previous_level) - - def test_data_create_bulk(self): - url = reverse('substrapp:data_sample-list') - - dir_path = os.path.dirname(os.path.realpath(__file__)) - - data_path1 = os.path.join(dir_path, '../../fixtures/chunantes/datasamples/datasample1/0024700.zip') - data_path2 = os.path.join(dir_path, '../../fixtures/chunantes/datasamples/datasample0/0024899.zip') - - pkhash1 = '24fb12ff87485f6b0bc5349e5bf7f36ccca4eb1353395417fdae7d8d787f178c' - pkhash2 = '30f6c797e277451b0a08da7119ed86fb2986fa7fab2258bf3edbd9f1752ed553' - - data_manager_keys = [get_hash(os.path.join(dir_path, '../../fixtures/chunantes/datamanagers/datamanager0/opener.py'))] - - data = { - 'files': [path_leaf(data_path1), path_leaf(data_path2)], - path_leaf(data_path1): open(data_path1, 'rb'), - path_leaf(data_path2): open(data_path2, 'rb'), - 'data_manager_keys': data_manager_keys, - 'test_only': False - } - - with mock.patch.object(DataManager.objects, 'filter') as mdatamanager, \ - mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: - - mdatamanager.return_value = FakeFilterDataManager(1) - mcreate.return_value = ({'keys': [pkhash1, pkhash2]}, - status.HTTP_201_CREATED) - response = self.client.post(url, data=data, format='multipart', **self.extra) - self.assertEqual([r['pkhash'] for r in response.data], [pkhash1, pkhash2]) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - - for x in data['files']: - data[x].close() - - def test_data_create_bulk_dryrun(self): - url = reverse('substrapp:data_sample-list') - - dir_path = os.path.dirname(os.path.realpath(__file__)) - - data_path1 = os.path.join(dir_path, '../../fixtures/chunantes/datasamples/datasample1/0024700.zip') - data_path2 = os.path.join(dir_path, '../../fixtures/chunantes/datasamples/datasample0/0024899.zip') - - data_manager_keys = [get_hash(os.path.join(dir_path, '../../fixtures/chunantes/datamanagers/datamanager0/opener.py'))] - - data = { - 'files': [path_leaf(data_path1), path_leaf(data_path2)], - path_leaf(data_path1): open(data_path1, 'rb'), - path_leaf(data_path2): open(data_path2, 'rb'), - 'data_manager_keys': data_manager_keys, - 'test_only': False, - 'dryrun': True - } - - with 
mock.patch.object(DataManager.objects, 'filter') as mdatamanager, \ - mock.patch.object(DataSampleViewSet, 'dryrun_task') as mdryrun_task: - - mdatamanager.return_value = FakeFilterDataManager(1) - mdryrun_task.return_value = (FakeTask('42'), 'Your dry-run has been taken in account. You can follow the task execution on localhost') - response = self.client.post(url, data=data, format='multipart', **self.extra) - - self.assertEqual(response.data['id'], '42') - self.assertEqual(response.data['message'], 'Your dry-run has been taken in account. You can follow the task execution on localhost') - self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) - - for x in data['files']: - data[x].close() - - def test_data_create(self): - url = reverse('substrapp:data_sample-list') - - dir_path = os.path.dirname(os.path.realpath(__file__)) - - data_path = os.path.join(dir_path, '../../fixtures/chunantes/datasamples/datasample1/0024700.zip') - - pkhash = '24fb12ff87485f6b0bc5349e5bf7f36ccca4eb1353395417fdae7d8d787f178c' - - data_manager_keys = [get_hash(os.path.join(dir_path, '../../fixtures/chunantes/datamanagers/datamanager0/opener.py'))] - - data = { - 'file': open(data_path, 'rb'), - 'data_manager_keys': data_manager_keys, - 'test_only': False - } - - with mock.patch.object(DataManager.objects, 'filter') as mdatamanager, \ - mock.patch.object(LedgerDataSampleSerializer, 'create') as mcreate: - - mdatamanager.return_value = FakeFilterDataManager(1) - mcreate.return_value = ({'keys': [pkhash]}, - status.HTTP_201_CREATED) - response = self.client.post(url, data=data, format='multipart', **self.extra) - - self.assertEqual(response.data[0]['pkhash'], pkhash) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - - data['file'].close() - - def test_data_create_dryrun(self): - - url = reverse('substrapp:data_sample-list') - - dir_path = os.path.dirname(os.path.realpath(__file__)) - - data_path = os.path.join(dir_path, '../../fixtures/chunantes/datasamples/datasample1/0024700.zip') - - data_manager_keys = [get_hash(os.path.join(dir_path, '../../fixtures/chunantes/datamanagers/datamanager0/opener.py'))] - - data = { - 'file': open(data_path, 'rb'), - 'data_manager_keys': data_manager_keys, - 'test_only': False, - 'dryrun': True - } - - with mock.patch.object(DataManager.objects, 'filter') as mdatamanager, \ - mock.patch.object(DataSampleViewSet, 'dryrun_task') as mdryrun_task: - - mdatamanager.return_value = FakeFilterDataManager(1) - mdryrun_task.return_value = (FakeTask('42'), 'Your dry-run has been taken in account. You can follow the task execution on localhost') - response = self.client.post(url, data=data, format='multipart', **self.extra) - - self.assertEqual(response.data['id'], '42') - self.assertEqual(response.data['message'], 'Your dry-run has been taken in account. 
You can follow the task execution on localhost') - self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) - - data['file'].close() - - def test_data_sample_compute_dryrun(self): - - dir_path = os.path.dirname(os.path.realpath(__file__)) - - data_path = os.path.join(dir_path, '../../fixtures/chunantes/datasamples/datasample1/0024700.zip') - - shutil.copy(data_path, os.path.join(MEDIA_ROOT, '0024700.zip')) - - opener_path = os.path.join(dir_path, '../../fixtures/chunantes/datamanagers/datamanager0/opener.py') - - pkhash = '62fb3263208d62c7235a046ee1d80e25512fe782254b730a9e566276b8c0ef3a' - - data = { - 'filepath': os.path.join(MEDIA_ROOT, '0024700.zip'), - 'pkhash': pkhash, - } - - data_files = [data] - data_manager_keys = [get_hash(opener_path)] - - with mock.patch.object(DataManager.objects, 'get') as mdatamanager: - mdatamanager.return_value = FakeDataManager(opener_path) - data_sample_compute_dryrun(data_files, data_manager_keys) diff --git a/substrabac/substrapp/utils.py b/substrabac/substrapp/utils.py deleted file mode 100644 index 5f48345bd..000000000 --- a/substrabac/substrapp/utils.py +++ /dev/null @@ -1,275 +0,0 @@ -import io -import hashlib -import json -import logging -import os -import tempfile -from os.path import isfile, isdir - -import requests -import subprocess -import tarfile -import zipfile - -from checksumdir import dirhash -from rest_framework import status - -from substrabac.settings.common import PROJECT_ROOT -from django.conf import settings - -LEDGER = getattr(settings, 'LEDGER', None) - - -def clean_env_variables(): - os.environ.pop('FABRIC_CFG_PATH', None) - os.environ.pop('CORE_PEER_MSPCONFIGPATH', None) - os.environ.pop('CORE_PEER_ADDRESS', None) - -####### -# /!\ # -####### - -# careful, passing invoke parameters to queryLedger will NOT fail - - -def queryLedger(options): - args = options['args'] - - channel_name = LEDGER['channel_name'] - chaincode_name = LEDGER['chaincode_name'] - core_peer_mspconfigpath = LEDGER['core_peer_mspconfigpath'] - peer = LEDGER['peer'] - - peer_port = peer["port"][os.environ.get('SUBSTRABAC_PEER_PORT', 'external')] - - # update config path for using right core.yaml and override msp config path - os.environ['FABRIC_CFG_PATH'] = os.environ.get('FABRIC_CFG_PATH_ENV', peer['docker_core_dir']) - os.environ['CORE_PEER_MSPCONFIGPATH'] = os.environ.get('CORE_PEER_MSPCONFIGPATH_ENV', core_peer_mspconfigpath) - os.environ['CORE_PEER_ADDRESS'] = os.environ.get('CORE_PEER_ADDRESS_ENV', f'{peer["host"]}:{peer_port}') - - print(f'Querying chaincode in the channel \'{channel_name}\' on the peer \'{peer["host"]}\' ...', flush=True) - - output = subprocess.run([os.path.join(PROJECT_ROOT, '../bin/peer'), - 'chaincode', 'query', - '-x', - '-C', channel_name, - '-n', chaincode_name, - '-c', args], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - - st = status.HTTP_200_OK - data = output.stdout.decode('utf-8') - if data: - # json transformation if needed - try: - data = json.loads(bytes.fromhex(data.rstrip()).decode('utf-8')) - except: - logging.error('Failed to json parse hexadecimal response in query') - - msg = f'Query of channel \'{channel_name}\' on the peer \'{peer["host"]}\' was successful\n' - print(msg, flush=True) - else: - try: - msg = output.stderr.decode('utf-8').split('Error')[-1].split('\n')[0] - data = {'message': msg} - except: - msg = output.stderr.decode('utf-8') - data = {'message': msg} - finally: - st = status.HTTP_400_BAD_REQUEST - if 'access denied' in msg: - st = status.HTTP_403_FORBIDDEN - elif 'no element with 
key' in msg: - st = status.HTTP_404_NOT_FOUND - - clean_env_variables() - - return data, st - - -def invokeLedger(options, sync=False): - args = options['args'] - - channel_name = LEDGER['channel_name'] - chaincode_name = LEDGER['chaincode_name'] - core_peer_mspconfigpath = LEDGER['core_peer_mspconfigpath'] - peer = LEDGER['peer'] - peer_port = peer["port"][os.environ.get('SUBSTRABAC_PEER_PORT', 'external')] - - orderer = LEDGER['orderer'] - orderer_ca_file = orderer['ca'] - peer_key_file = peer['clientKey'] - peer_cert_file = peer['clientCert'] - - # update config path for using right core.yaml and override msp config path - os.environ['FABRIC_CFG_PATH'] = os.environ.get('FABRIC_CFG_PATH_ENV', peer['docker_core_dir']) - os.environ['CORE_PEER_MSPCONFIGPATH'] = os.environ.get('CORE_PEER_MSPCONFIGPATH_ENV', core_peer_mspconfigpath) - os.environ['CORE_PEER_ADDRESS'] = os.environ.get('CORE_PEER_ADDRESS_ENV', f'{peer["host"]}:{peer_port}') - - print(f'Sending invoke transaction to {peer["host"]} ...', flush=True) - - cmd = [os.path.join(PROJECT_ROOT, '../bin/peer'), - 'chaincode', 'invoke', - '-C', channel_name, - '-n', chaincode_name, - '-c', args, - '-o', f'{orderer["host"]}:{orderer["port"]}', - '--cafile', orderer_ca_file, - '--tls', - '--clientauth', - '--keyfile', peer_key_file, - '--certfile', peer_cert_file] - - if sync: - cmd += ['--waitForEvent', '--waitForEventTimeout', '45s'] - - output = subprocess.run(cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - - st = status.HTTP_201_CREATED - data = output.stdout.decode('utf-8') - - if not data: - msg = output.stderr.decode('utf-8') - data = {'message': msg} - - if 'Error' in msg or 'ERRO' in msg: - # https://github.com/hyperledger/fabric/blob/eca1b14b7e3453a5d32296af79cc7bad10c7673b/peer/chaincode/common.go - if "timed out waiting for txid on all peers" in msg or "failed to receive txid on all peers" in msg: - st = status.HTTP_408_REQUEST_TIMEOUT - else: - st = status.HTTP_400_BAD_REQUEST - elif 'access denied' in msg or 'authentication handshake failed' in msg: - st = status.HTTP_403_FORBIDDEN - elif 'Chaincode invoke successful' in msg: - st = status.HTTP_201_CREATED - try: - msg = msg.split('result: status:')[1].split('\n')[0].split('payload:')[1].strip().strip('"') - except: - pass - else: - msg = json.loads(msg.encode('utf-8').decode('unicode_escape')) - msg = msg.get('key', msg.get('keys')) # get pkhash - finally: - data = {'pkhash': msg} - - clean_env_variables() - - return data, st - - -def get_dir_hash(archive_content): - with tempfile.TemporaryDirectory() as temp_dir: - try: - content = archive_content.read() - archive_content.seek(0) - uncompress_content(content, temp_dir) - except Exception as e: - logging.error(e) - raise e - else: - return dirhash(temp_dir, 'sha256') - - -def get_hash(file, key=None): - if file is None: - return '' - else: - if isinstance(file, (str, bytes, os.PathLike)): - if isfile(file): - with open(file, 'rb') as f: - data = f.read() - elif isdir(file): - return dirhash(file, 'sha256') - else: - return '' - else: - openedfile = file.open() - data = openedfile.read() - openedfile.seek(0) - - return compute_hash(data, key) - - -def compute_hash(bytes, key=None): - sha256_hash = hashlib.sha256() - - if isinstance(bytes, str): - bytes = bytes.encode() - - if key is not None and isinstance(key, str): - bytes += key.encode() - - sha256_hash.update(bytes) - - return sha256_hash.hexdigest() - - -def get_computed_hash(url, key=None): - username = getattr(settings, 'BASICAUTH_USERNAME', None) - password = 
getattr(settings, 'BASICAUTH_PASSWORD', None) - - kwargs = {} - - if username is not None and password is not None: - kwargs.update({'auth': (username, password)}) - - if settings.DEBUG: - kwargs.update({'verify': False}) - - try: - r = requests.get(url, headers={'Accept': 'application/json;version=0.0'}, **kwargs) - except: - raise Exception(f'Failed to check hash due to failed file fetching {url}') - else: - if r.status_code != 200: - raise Exception( - f'Url: {url} to fetch file returned status code: {r.status_code}') - - computedHash = compute_hash(r.content, key) - - return r.content, computedHash - - -def get_remote_file(object, key=None): - content, computed_hash = get_computed_hash(object['storageAddress'], key) - - if computed_hash != object['hash']: - msg = 'computed hash is not the same as the hosted file. Please investigate for default of synchronization, corruption, or hacked' - raise Exception(msg) - - return content, computed_hash - - -def create_directory(directory): - if not os.path.exists(directory): - os.makedirs(directory) - - -def uncompress_path(archive_path, to_directory): - if zipfile.is_zipfile(archive_path): - zip_ref = zipfile.ZipFile(archive_path, 'r') - zip_ref.extractall(to_directory) - zip_ref.close() - elif tarfile.is_tarfile(archive_path): - tar = tarfile.open(archive_path, 'r:*') - tar.extractall(to_directory) - tar.close() - else: - raise Exception('Archive must be zip or tar.gz') - - -def uncompress_content(archive_content, to_directory): - if zipfile.is_zipfile(io.BytesIO(archive_content)): - zip_ref = zipfile.ZipFile(io.BytesIO(archive_content)) - zip_ref.extractall(to_directory) - zip_ref.close() - else: - try: - tar = tarfile.open(fileobj=io.BytesIO(archive_content)) - tar.extractall(to_directory) - tar.close() - except tarfile.TarError: - raise Exception('Archive must be zip or tar.*') diff --git a/substrabac/substrapp/views/__init__.py b/substrabac/substrapp/views/__init__.py deleted file mode 100644 index d3149b744..000000000 --- a/substrabac/substrapp/views/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# encoding: utf-8 - -from .datasample import DataSampleViewSet -from .datamanager import DataManagerViewSet -from .objective import ObjectiveViewSet -from .model import ModelViewSet -from .algo import AlgoViewSet -from .traintuple import TrainTupleViewSet -from .testtuple import TestTupleViewSet -from .task import TaskViewSet - -__all__ = ['DataSampleViewSet', 'DataManagerViewSet', 'ObjectiveViewSet', 'ModelViewSet', - 'AlgoViewSet', 'TrainTupleViewSet', 'TestTupleViewSet', - 'TaskViewSet'] diff --git a/substrabac/substrapp/views/algo.py b/substrabac/substrapp/views/algo.py deleted file mode 100644 index bad64fffe..000000000 --- a/substrabac/substrapp/views/algo.py +++ /dev/null @@ -1,221 +0,0 @@ -import tempfile - -import requests - -from django.http import Http404 -from rest_framework import status, mixins -from rest_framework.decorators import action -from rest_framework.exceptions import ValidationError -from rest_framework.response import Response -from rest_framework.viewsets import GenericViewSet - -from substrapp.models import Algo -from substrapp.serializers import LedgerAlgoSerializer, AlgoSerializer -from substrapp.utils import queryLedger, get_hash -from substrapp.views.utils import get_filters, getObjectFromLedger, ComputeHashMixin, ManageFileMixin, JsonException, find_primary_key_error - - -class AlgoViewSet(mixins.CreateModelMixin, - mixins.RetrieveModelMixin, - mixins.ListModelMixin, - ComputeHashMixin, - ManageFileMixin, - 
GenericViewSet): - queryset = Algo.objects.all() - serializer_class = AlgoSerializer - ledger_query_call = 'queryAlgo' - - def perform_create(self, serializer): - return serializer.save() - - def create(self, request, *args, **kwargs): - data = request.data - - file = data.get('file') - pkhash = get_hash(file) - serializer = self.get_serializer(data={ - 'pkhash': pkhash, - 'file': file, - 'description': data.get('description') - }) - - try: - serializer.is_valid(raise_exception=True) - except Exception as e: - st = status.HTTP_400_BAD_REQUEST - if find_primary_key_error(e): - st = status.HTTP_409_CONFLICT - return Response({'message': e.args, 'pkhash': pkhash}, status=st) - else: - - # create on db - try: - instance = self.perform_create(serializer) - except Exception as exc: - return Response({'message': exc.args}, - status=status.HTTP_400_BAD_REQUEST) - else: - # init ledger serializer - ledger_serializer = LedgerAlgoSerializer(data={'name': data.get('name'), - 'permissions': data.get('permissions', 'all'), - 'instance': instance}, - context={'request': request}) - if not ledger_serializer.is_valid(): - # delete instance - instance.delete() - raise ValidationError(ledger_serializer.errors) - - # create on ledger - data, st = ledger_serializer.create(ledger_serializer.validated_data) - - if st not in (status.HTTP_201_CREATED, status.HTTP_202_ACCEPTED, status.HTTP_408_REQUEST_TIMEOUT): - return Response(data, status=st) - - headers = self.get_success_headers(serializer.data) - d = dict(serializer.data) - d.update(data) - return Response(d, status=st, headers=headers) - - def create_or_update_algo(self, algo, pk): - try: - # get algo description from remote node - url = algo['description']['storageAddress'] - try: - r = requests.get(url, headers={'Accept': 'application/json;version=0.0'}) # TODO pass cert - except: - raise Exception(f'Failed to fetch {url}') - else: - if r.status_code != 200: - raise Exception(f'end to end node report {r.text}') - - try: - computed_hash = self.compute_hash(r.content) - except Exception: - raise Exception('Failed to fetch description file') - else: - if computed_hash != algo['description']['hash']: - msg = 'computed hash is not the same as the hosted file. 
Please investigate for default of synchronization, corruption, or hacked' - raise Exception(msg) - - f = tempfile.TemporaryFile() - f.write(r.content) - - # save/update objective in local db for later use - instance, created = Algo.objects.update_or_create(pkhash=pk, validated=True) - instance.description.save('description.md', f) - except Exception as e: - raise e - else: - return instance - - def retrieve(self, request, *args, **kwargs): - lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - pk = self.kwargs[lookup_url_kwarg] - - if len(pk) != 64: - return Response({'message': f'Wrong pk {pk}'}, status.HTTP_400_BAD_REQUEST) - - try: - int(pk, 16) # test if pk is correct (hexadecimal) - except: - return Response({'message': f'Wrong pk {pk}'}, status.HTTP_400_BAD_REQUEST) - else: - # get instance from remote node - error = None - instance = None - try: - data = getObjectFromLedger(pk, self.ledger_query_call) - except JsonException as e: - return Response(e.msg, status=status.HTTP_400_BAD_REQUEST) - except Http404: - return Response(f'No element with key {pk}', status=status.HTTP_404_NOT_FOUND) - else: - try: - # try to get it from local db to check if description exists - instance = self.get_object() - except Http404: - try: - instance = self.create_or_update_algo(data, pk) - except Exception as e: - error = e - else: - # check if instance has description - if not instance.description: - try: - instance = self.create_or_update_algo(data, pk) - except Exception as e: - error = e - finally: - if error is not None: - return Response(str(error), status=status.HTTP_400_BAD_REQUEST) - - # do not give access to local files address - if instance is not None: - serializer = self.get_serializer(instance, fields=('owner', 'pkhash', 'creation_date', 'last_modified')) - data.update(serializer.data) - else: - data = {'message': 'Fail to get instance'} - - return Response(data, status=status.HTTP_200_OK) - - def list(self, request, *args, **kwargs): - # can modify result by interrogating `request.version` - - data, st = queryLedger({ - 'args': '{"Args":["queryAlgos"]}' - }) - - modelData = None - - # init list to return - if data is None: - data = [] - l = [data] - - if st == 200: - - # parse filters - query_params = request.query_params.get('search', None) - - if query_params is not None: - try: - filters = get_filters(query_params) - except Exception as exc: - return Response( - {'message': f'Malformed search filters {query_params}'}, - status=status.HTTP_400_BAD_REQUEST) - else: - # filtering, reinit l to empty array - l = [] - for idx, filter in enumerate(filters): - # init each list iteration to data - l.append(data) - for k, subfilters in filter.items(): - if k == 'algo': # filter by own key - for key, val in subfilters.items(): - l[idx] = [x for x in l[idx] if x[key] in val] - elif k == 'model': # select objectives used by outModel hash - if not modelData: - # TODO find a way to put this call in cache - modelData, st = queryLedger({ - 'args': '{"Args":["queryTraintuples"]}' - }) - if st != status.HTTP_200_OK: - return Response(modelData, status=st) - if modelData is None: - modelData = [] - - for key, val in subfilters.items(): - filteredData = [x for x in modelData if x['outModel'] is not None and x['outModel'][key] in val] - algoKeys = [x['algo']['hash'] for x in filteredData] - l[idx] = [x for x in l[idx] if x['key'] in algoKeys] - - return Response(l, status=st) - - @action(detail=True) - def file(self, request, *args, **kwargs): - return self.manage_file('file') - - 
@action(detail=True) - def description(self, request, *args, **kwargs): - return self.manage_file('description') diff --git a/substrabac/substrapp/views/datamanager.py b/substrabac/substrapp/views/datamanager.py deleted file mode 100644 index 783697216..000000000 --- a/substrabac/substrapp/views/datamanager.py +++ /dev/null @@ -1,371 +0,0 @@ -import ast -import tempfile - -import requests -from django.conf import settings -from django.http import Http404 -from rest_framework import status, mixins -from rest_framework.decorators import action -from rest_framework.exceptions import ValidationError -from rest_framework.response import Response -from rest_framework.viewsets import GenericViewSet - -# from hfc.fabric import Client -# cli = Client(net_profile="../network.json") -from substrapp.models import DataManager -from substrapp.serializers import DataManagerSerializer, LedgerDataManagerSerializer -from substrapp.serializers.ledger.datamanager.util import updateLedgerDataManager -from substrapp.serializers.ledger.datamanager.tasks import updateLedgerDataManagerAsync -from substrapp.utils import queryLedger, get_hash -from substrapp.views.utils import get_filters, ManageFileMixin, ComputeHashMixin, JsonException, find_primary_key_error - - -class DataManagerViewSet(mixins.CreateModelMixin, - mixins.RetrieveModelMixin, - mixins.ListModelMixin, - ComputeHashMixin, - ManageFileMixin, - GenericViewSet): - queryset = DataManager.objects.all() - serializer_class = DataManagerSerializer - ledger_query_call = 'queryDataManager' - - def perform_create(self, serializer): - return serializer.save() - - def dryrun(self, data_opener): - - file = data_opener.open().read() - - try: - node = ast.parse(file) - except: - return Response({'message': f'Opener must be a valid python file, please review your opener file and the documentation.'}, - status=status.HTTP_400_BAD_REQUEST) - - imported_module_names = [m.name for e in node.body if isinstance(e, ast.Import) for m in e.names] - if 'substratools' not in imported_module_names: - return Response({'message': 'Opener must import substratools, please review your opener and the documentation.'}, - status=status.HTTP_400_BAD_REQUEST) - - return Response({'message': f'Your data opener is valid. 
You can remove the dryrun option.'}, - status=status.HTTP_200_OK) - - def create(self, request, *args, **kwargs): - data = request.data - - dryrun = data.get('dryrun', False) - - data_opener = data.get('data_opener') - - pkhash = get_hash(data_opener) - serializer = self.get_serializer(data={ - 'pkhash': pkhash, - 'data_opener': data_opener, - 'description': data.get('description'), - 'name': data.get('name'), - }) - - try: - serializer.is_valid(raise_exception=True) - except Exception as e: - st = status.HTTP_400_BAD_REQUEST - if find_primary_key_error(e): - st = status.HTTP_409_CONFLICT - return Response({'message': e.args, 'pkhash': pkhash}, status=st) - else: - if dryrun: - return self.dryrun(data_opener) - - # create on db - try: - instance = self.perform_create(serializer) - except Exception as e: - return Response({'message': e.args}, - status=status.HTTP_400_BAD_REQUEST) - else: - # init ledger serializer - ledger_serializer = LedgerDataManagerSerializer(data={'name': data.get('name'), - 'permissions': data.get('permissions'), - 'type': data.get('type'), - 'objective_keys': data.getlist('objective_keys'), - 'instance': instance}, - context={'request': request}) - - if not ledger_serializer.is_valid(): - # delete instance - instance.delete() - raise ValidationError(ledger_serializer.errors) - - # create on ledger - data, st = ledger_serializer.create(ledger_serializer.validated_data) - - if st not in (status.HTTP_201_CREATED, status.HTTP_202_ACCEPTED, status.HTTP_408_REQUEST_TIMEOUT): - return Response(data, status=st) - - headers = self.get_success_headers(serializer.data) - d = dict(serializer.data) - d.update(data) - return Response(d, status=st, headers=headers) - - def create_or_update_datamanager(self, instance, datamanager, pk): - - # create instance if does not exist - if not instance: - instance, created = DataManager.objects.update_or_create(pkhash=pk, name=datamanager['name'], validated=True) - - if not instance.data_opener: - try: - url = datamanager['opener']['storageAddress'] - try: - r = requests.get(url, headers={'Accept': 'application/json;version=0.0'}) - except: - raise Exception(f'Failed to fetch {url}') - else: - if r.status_code != 200: - raise Exception(f'end to end node report {r.text}') - - try: - computed_hash = self.compute_hash(r.content) - except Exception: - raise Exception('Failed to fetch opener file') - else: - if computed_hash != pk: - msg = 'computed hash is not the same as the hosted file. Please investigate for default of synchronization, corruption, or hacked' - raise Exception(msg) - - f = tempfile.TemporaryFile() - f.write(r.content) - - # save/update data_opener in local db for later use - instance.data_opener.save('opener.py', f) - - except Exception as e: - raise e - - if not instance.description: - # do the same for description - url = datamanager['description']['storageAddress'] - try: - r = requests.get(url, headers={'Accept': 'application/json;version=0.0'}) - except: - raise Exception(f'Failed to fetch {url}') - else: - if r.status_code != status.HTTP_200_OK: - raise Exception(f'end to end node report {r.text}') - - try: - computed_hash = self.compute_hash(r.content) - except Exception: - raise Exception('Failed to fetch description file') - else: - if computed_hash != datamanager['description']['hash']: - msg = 'computed hash is not the same as the hosted file. 
Please investigate for default of synchronization, corruption, or hacked' - raise Exception(msg) - - f = tempfile.TemporaryFile() - f.write(r.content) - - # save/update description in local db for later use - instance.description.save('description.md', f) - - return instance - - def getObjectFromLedger(self, pk): - # get instance from remote node - data, st = queryLedger({ - 'args': f'{{"Args":["queryDataset", "{pk}"]}}' - }) - - if st == status.HTTP_404_NOT_FOUND: - raise Http404('Not found') - - if st != status.HTTP_200_OK: - raise JsonException(data) - - if data['permissions'] == 'all': - return data - else: - raise Exception('Not Allowed') - - def retrieve(self, request, *args, **kwargs): - lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - pk = self.kwargs[lookup_url_kwarg] - - if len(pk) != 64: - return Response({'message': f'Wrong pk {pk}'}, status.HTTP_400_BAD_REQUEST) - - try: - int(pk, 16) # test if pk is correct (hexadecimal) - except: - return Response({'message': f'Wrong pk {pk}'}, status.HTTP_400_BAD_REQUEST) - else: - # get instance from remote node - try: - data = self.getObjectFromLedger(pk) # datamanager use particular query to ledger - except JsonException as e: - return Response(e.msg, status=status.HTTP_400_BAD_REQUEST) - except Http404: - return Response(f'No element with key {pk}', status=status.HTTP_404_NOT_FOUND) - else: - error = None - instance = None - try: - # try to get it from local db to check if description exists - instance = self.get_object() - except Http404: - try: - instance = self.create_or_update_datamanager(instance, data, pk) - except Exception as e: - error = e - else: - # check if instance has description or data_opener - if not instance.description or not instance.data_opener: - try: - instance = self.create_or_update_datamanager(instance, data, pk) - except Exception as e: - error = e - finally: - if error is not None: - return Response({'message': str(error)}, status=status.HTTP_400_BAD_REQUEST) - - # do not give access to local files address - if instance is not None: - serializer = self.get_serializer(instance, fields=('owner', 'pkhash', 'creation_date', 'last_modified')) - data.update(serializer.data) - else: - data = {'message': 'Fail to get instance'} - - return Response(data, status=status.HTTP_200_OK) - - def list(self, request, *args, **kwargs): - # can modify result by interrogating `request.version` - - data, st = queryLedger({ - 'args': '{"Args":["queryDataManagers"]}' - }) - objectiveData = None - algoData = None - modelData = None - - # init list to return - if data is None: - data = [] - l = [data] - - if st == 200: - - # parse filters - query_params = request.query_params.get('search', None) - - if query_params is not None: - try: - filters = get_filters(query_params) - except Exception as exc: - return Response( - {'message': f'Malformed search filters {query_params}'}, - status=status.HTTP_400_BAD_REQUEST) - else: - # filtering, reinit l to empty array - l = [] - for idx, filter in enumerate(filters): - # init each list iteration to data - l.append(data) - for k, subfilters in filter.items(): - if k == 'dataset': # filter by own key - for key, val in subfilters.items(): - l[idx] = [x for x in l[idx] if x[key] in val] - elif k == 'objective': # select objective used by these datamanagers - if not objectiveData: - # TODO find a way to put this call in cache - objectiveData, st = queryLedger({ - 'args': '{"Args":["queryObjectives"]}' - }) - if st != status.HTTP_200_OK: - return Response(objectiveData, status=st) - if 
objectiveData is None: - objectiveData = [] - - for key, val in subfilters.items(): - if key == 'metrics': # specific to nested metrics - filteredData = [x for x in objectiveData if x[key]['name'] in val] - else: - filteredData = [x for x in objectiveData if x[key] in val] - objectiveKeys = [x['key'] for x in filteredData] - l[idx] = [x for x in l[idx] if x['objectiveKey'] in objectiveKeys] - elif k == 'model': # select objectives used by outModel hash - if not modelData: - # TODO find a way to put this call in cache - modelData, st = queryLedger({ - 'args': '{"Args":["queryTraintuples"]}' - }) - if st != status.HTTP_200_OK: - return Response(modelData, status=st) - if modelData is None: - modelData = [] - - for key, val in subfilters.items(): - filteredData = [x for x in modelData if x['outModel'] is not None and x['outModel'][key] in val] - objectiveKeys = [x['objective']['hash'] for x in filteredData] - l[idx] = [x for x in l[idx] if x['objectiveKey'] in objectiveKeys] - - return Response(l, status=st) - - @action(methods=['post'], detail=True) - def update_ledger(self, request, *args, **kwargs): - - lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - pk = self.kwargs[lookup_url_kwarg] - - if len(pk) != 64: - return Response({'message': f'Wrong pk {pk}'}, - status.HTTP_400_BAD_REQUEST) - - try: - int(pk, 16) # test if pk is correct (hexadecimal) - except: - return Response({'message': f'Wrong pk {pk}'}, - status.HTTP_400_BAD_REQUEST) - else: - - data = request.data - objective_key = data.get('objective_key') - - if len(pk) != 64: - return Response({'message': f'Objective Key is wrong: {pk}'}, - status.HTTP_400_BAD_REQUEST) - - try: - int(pk, 16) # test if pk is correct (hexadecimal) - except: - return Response({'message': f'Objective Key is wrong: {pk}'}, - status.HTTP_400_BAD_REQUEST) - else: - args = '"%(dataManagerKey)s", "%(objectiveKey)s"' % { - 'dataManagerKey': pk, - 'objectiveKey': objective_key, - } - - if getattr(settings, 'LEDGER_SYNC_ENABLED'): - data, st = updateLedgerDataManager(args, sync=True) - - # patch status for update - if st == status.HTTP_201_CREATED: - st = status.HTTP_200_OK - return Response(data, status=st) - else: - # use a celery task, as we are in an http request transaction - updateLedgerDataManagerAsync.delay(args) - data = { - 'message': 'The substra network has been notified for updating this DataManager' - } - st = status.HTTP_202_ACCEPTED - return Response(data, status=st) - - @action(detail=True) - def description(self, request, *args, **kwargs): - return self.manage_file('description') - - @action(detail=True) - def opener(self, request, *args, **kwargs): - return self.manage_file('data_opener') diff --git a/substrabac/substrapp/views/datasample.py b/substrabac/substrapp/views/datasample.py deleted file mode 100644 index 4b8cab446..000000000 --- a/substrabac/substrapp/views/datasample.py +++ /dev/null @@ -1,330 +0,0 @@ -import logging -from os.path import normpath - -import docker -import os -import ntpath -import uuid - -from checksumdir import dirhash -from django.conf import settings -from docker.errors import ContainerError -from rest_framework import status, mixins -from rest_framework.decorators import action -from rest_framework.exceptions import ValidationError -from rest_framework.response import Response -from rest_framework.viewsets import GenericViewSet -from rest_framework.reverse import reverse - -from substrabac.celery import app - -from substrapp.models import DataSample, DataManager -from substrapp.serializers import 
DataSampleSerializer, LedgerDataSampleSerializer -from substrapp.serializers.ledger.datasample.util import updateLedgerDataSample -from substrapp.serializers.ledger.datasample.tasks import updateLedgerDataSampleAsync -from substrapp.utils import uncompress_path, get_dir_hash -from substrapp.tasks import build_subtuple_folders, remove_subtuple_materials -from substrapp.views.utils import find_primary_key_error - -logger = logging.getLogger('django.request') - - -def path_leaf(path): - head, tail = ntpath.split(path) - return tail or ntpath.basename(head) - - -class LedgerException(Exception): - def __init__(self, data, st): - self.data = data - self.st = st - super(LedgerException).__init__() - - -class ValidationException(Exception): - def __init__(self, data, pkhash, st): - self.data = data - self.pkhash = pkhash - self.st = st - super(ValidationException).__init__() - - -@app.task(bind=True, ignore_result=False) -def compute_dryrun(self, data, data_manager_keys): - from shutil import copy - from substrapp.models import DataManager - - client = docker.from_env() - - # Name of the dry-run subtuple (not important) - pkhash = data[0]['pkhash'] - dryrun_uuid = f'{pkhash}_{uuid.uuid4().hex}' - subtuple_directory = build_subtuple_folders({'key': dryrun_uuid}) - data_path = os.path.join(subtuple_directory, 'data') - volumes = {} - - try: - - for data_sample in data: - # uncompress only for file - if 'file' in data_sample: - try: - uncompress_path(data_sample['file'], os.path.join(data_path, data_sample['pkhash'])) - except Exception as e: - raise e - # for all data paths, we need to create symbolic links inside data_path - # and add real path to volume bind docker - elif 'path' in data_sample: - os.symlink(data_sample['path'], os.path.join(data_path, data_sample['pkhash'])) - volumes.update({data_sample['path']: {'bind': data_sample['path'], 'mode': 'ro'}}) - - for datamanager_key in data_manager_keys: - datamanager = DataManager.objects.get(pk=datamanager_key) - copy(datamanager.data_opener.path, os.path.join(subtuple_directory, 'opener/opener.py')) - - # Launch verification - opener_file = os.path.join(subtuple_directory, 'opener/opener.py') - data_sample_docker_path = os.path.join(getattr(settings, 'PROJECT_ROOT'), 'fake_data_sample') # fake_data comes with substrabac - - data_docker = 'data_dry_run' # tag must be lowercase for docker - data_docker_name = f'{data_docker}_{dryrun_uuid}' - - volumes.update({data_path: {'bind': '/sandbox/data', 'mode': 'rw'}, - opener_file: {'bind': '/sandbox/opener/__init__.py', 'mode': 'ro'}}) - - client.images.build(path=data_sample_docker_path, - tag=data_docker, - rm=False) - - job_args = {'image': data_docker, - 'name': data_docker_name, - 'cpuset_cpus': '0-0', - 'mem_limit': '1G', - 'command': None, - 'volumes': volumes, - 'shm_size': '8G', - 'labels': ['dryrun'], - 'detach': False, - 'auto_remove': False, - 'remove': False} - - client.containers.run(**job_args) - - except ContainerError as e: - raise Exception(e.stderr) - finally: - try: - container = client.containers.get(data_docker_name) - container.remove() - except: - logger.error('Could not remove containers') - remove_subtuple_materials(subtuple_directory) - for data_sample in data: - if 'file' in data_sample and os.path.exists(data_sample['file']): - os.remove(data_sample['file']) - - -class DataSampleViewSet(mixins.CreateModelMixin, - mixins.RetrieveModelMixin, - # mixins.UpdateModelMixin, - # mixins.DestroyModelMixin, - # mixins.ListModelMixin, - GenericViewSet): - queryset = 
DataSample.objects.all() - serializer_class = DataSampleSerializer - - def dryrun_task(self, data, data_manager_keys): - task = compute_dryrun.apply_async((data, data_manager_keys), - queue=f"{settings.LEDGER['name']}.dryrunner") - current_site = getattr(settings, "DEFAULT_DOMAIN") - task_route = f'{current_site}{reverse("substrapp:task-detail", args=[task.id])}' - return task, f'Your dry-run has been taken in account. You can follow the task execution on {task_route}' - - @staticmethod - def check_datamanagers(data_manager_keys): - datamanager_count = DataManager.objects.filter(pkhash__in=data_manager_keys).count() - - if datamanager_count != len(data_manager_keys): - raise Exception(f'One or more datamanager keys provided do not exist in local substrabac database. Please create them before. DataManager keys: {data_manager_keys}') - - @staticmethod - def commit(serializer, ledger_data): - instances = serializer.save() # can raise - # init ledger serializer - ledger_data.update({'instances': instances}) - ledger_serializer = LedgerDataSampleSerializer(data=ledger_data) - - if not ledger_serializer.is_valid(): - # delete instance - for instance in instances: - instance.delete() - raise ValidationError(ledger_serializer.errors) - - # create on ledger - data, st = ledger_serializer.create(ledger_serializer.validated_data) - - if st == status.HTTP_408_REQUEST_TIMEOUT: - data.update({'pkhash': [x['pkhash'] for x in serializer.data]}) - raise LedgerException(data, st) - - if st not in (status.HTTP_201_CREATED, status.HTTP_202_ACCEPTED): - raise LedgerException(data, st) - - # update validated to True in response - if 'pkhash' in data and data['validated']: - for d in serializer.data: - if d['pkhash'] in data['pkhash']: - d.update({'validated': data['validated']}) - - return serializer.data, st - - def compute_data(self, request): - data = {} - # files, should be archive - for k, file in request.FILES.items(): - pkhash = get_dir_hash(file) # can raise - # check pkhash does not belong to the list - try: - existing = data[pkhash] - except KeyError: - pass - else: - raise Exception(f'Your data sample archives contain same files leading to same pkhash, please review the content of your achives. Archives {file} and {existing["file"]} are the same') - data[pkhash] = { - 'pkhash': pkhash, - 'file': file - } - - # path/paths case - path = request.POST.get('path', None) - paths = request.POST.getlist('paths', []) - - if path and paths: - raise Exception('Cannot use path and paths together.') - - if path is not None: - paths = [path] - - # paths, should be directories - for path in paths: - if not os.path.isdir(path): - raise Exception(f'One of your paths does not exist, is not a directory or is not an absolute path: {path}') - pkhash = dirhash(path, 'sha256') - try: - existing = data[pkhash] - except KeyError: - pass - else: - # existing can be a dict with a field path or file - raise Exception(f'Your data sample directory contain same files leading to same pkhash. 
Invalid path: {path}.') - - data[pkhash] = { - 'pkhash': pkhash, - 'path': normpath(path) - } - - if not data: # data empty - raise Exception(f'No data sample provided.') - - return list(data.values()) - - def handle_dryrun(self, data, data_manager_keys): - data_dry_run = [] - - # write uploaded file to disk - for d in data: - pkhash = d['pkhash'] - if 'file' in d: - file_path = os.path.join(getattr(settings, 'DRYRUN_ROOT'), - f'data_{pkhash}.zip') - with open(file_path, 'wb') as f: - f.write(d['file'].open().read()) - - data_dry_run.append({ - 'pkhash': pkhash, - 'file': file_path - }) - - if 'path' in d: - data_dry_run.append(d) - - try: - task, msg = self.dryrun_task(data_dry_run, data_manager_keys) - except Exception as e: - return Exception(f'Could not launch data creation with dry-run on this instance: {str(e)}') - else: - return {'id': task.id, 'message': msg}, status.HTTP_202_ACCEPTED - - def _create(self, request, data_manager_keys, test_only, dryrun): - - if not data_manager_keys: - raise Exception("missing or empty field 'data_manager_keys'") - - self.check_datamanagers(data_manager_keys) # can raise - - computed_data = self.compute_data(request) - - serializer = self.get_serializer(data=computed_data, many=True) - - try: - serializer.is_valid(raise_exception=True) - except Exception as e: - pkhashes = [x['pkhash'] for x in computed_data] - st = status.HTTP_400_BAD_REQUEST - if find_primary_key_error(e): - st = status.HTTP_409_CONFLICT - raise ValidationException(e.args, pkhashes, st) - else: - if dryrun: - return self.handle_dryrun(computed_data, data_manager_keys) - - # create on ledger + db - ledger_data = {'test_only': test_only, - 'data_manager_keys': data_manager_keys} - data, st = self.commit(serializer, ledger_data) - return data, st - - def create(self, request, *args, **kwargs): - dryrun = request.data.get('dryrun', False) - test_only = request.data.get('test_only', False) - data_manager_keys = request.data.getlist('data_manager_keys', []) - - try: - data, st = self._create(request, data_manager_keys, test_only, dryrun) - except ValidationException as e: - return Response({'message': e.data, 'pkhash': e.pkhash}, status=e.st) - except LedgerException as e: - return Response({'message': e.data}, status=e.st) - except Exception as e: - return Response({'message': str(e)}, status=status.HTTP_400_BAD_REQUEST) - else: - headers = self.get_success_headers(data) - return Response(data, status=st, headers=headers) - - @action(methods=['post'], detail=False) - def bulk_update(self, request): - - data = request.data - data_manager_keys = data.getlist('data_manager_keys') - data_keys = data.getlist('data_sample_keys') - - args = '"%(hashes)s", "%(dataManagerKeys)s"' % { - 'hashes': ','.join(data_keys), - 'dataManagerKeys': ','.join(data_manager_keys), - } - - if getattr(settings, 'LEDGER_SYNC_ENABLED'): - data, st = updateLedgerDataSample(args, sync=True) - - # patch status for update - if st == status.HTTP_201_CREATED: - st = status.HTTP_200_OK - return Response(data, status=st) - else: - # use a celery task, as we are in an http request transaction - updateLedgerDataSampleAsync.delay(args) - data = { - 'message': 'The substra network has been notified for updating these Data' - } - st = status.HTTP_202_ACCEPTED - return Response(data, status=st) diff --git a/substrabac/substrapp/views/model.py b/substrabac/substrapp/views/model.py deleted file mode 100644 index 34a541eb3..000000000 --- a/substrabac/substrapp/views/model.py +++ /dev/null @@ -1,223 +0,0 @@ -import os -import 
tempfile - -import requests -from django.http import Http404 -from rest_framework import status, mixins -from rest_framework.decorators import action -from rest_framework.response import Response -from rest_framework.viewsets import GenericViewSet - -from substrapp.models import Model -from substrapp.serializers import ModelSerializer - -# from hfc.fabric import Client -# cli = Client(net_profile="../network.json") -from substrapp.utils import queryLedger -from substrapp.views.utils import get_filters, ComputeHashMixin, getObjectFromLedger, CustomFileResponse, JsonException - - -class ModelViewSet(mixins.RetrieveModelMixin, - mixins.ListModelMixin, - ComputeHashMixin, - GenericViewSet): - queryset = Model.objects.all() - serializer_class = ModelSerializer - - # permission_classes = (permissions.IsAuthenticated,) - - def create_or_update_model(self, traintuple, pk): - if traintuple['outModel'] is None: - raise Exception(f'This traintuple related to this model key {pk} does not have a outModel') - - try: - # get objective description from remote node - url = traintuple['outModel']['storageAddress'] - try: - r = requests.get(url, headers={'Accept': 'application/json;version=0.0'}) # TODO pass cert - except: - raise Exception(f'Failed to fetch {url}') - else: - if r.status_code != 200: - raise Exception(f'end to end node report {r.text}') - - try: - computed_hash = self.compute_hash(r.content, traintuple['key']) - except Exception: - raise Exception('Failed to fetch outModel file') - else: - if computed_hash != pk: - msg = 'computed hash is not the same as the hosted file. Please investigate for default of synchronization, corruption, or hacked' - raise Exception(msg) - - f = tempfile.TemporaryFile() - f.write(r.content) - - # save/update objective in local db for later use - instance, created = Model.objects.update_or_create(pkhash=pk, validated=True) - instance.file.save('model', f) - except Exception as e: - raise e - else: - return instance - - def retrieve(self, request, *args, **kwargs): - lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - pk = self.kwargs[lookup_url_kwarg] - - if len(pk) != 64: - return Response({'message': f'Wrong pk {pk}'}, status.HTTP_400_BAD_REQUEST) - - try: - int(pk, 16) # test if pk is correct (hexadecimal) - except: - return Response({'message': f'Wrong pk {pk}'}, status.HTTP_400_BAD_REQUEST) - else: - # get instance from remote node - try: - data = getObjectFromLedger(pk, 'queryModelDetails') - except JsonException as e: - return Response(e.msg, status=status.HTTP_400_BAD_REQUEST) - except Http404: - return Response(f'No element with key {pk}', status=status.HTTP_404_NOT_FOUND) - else: - error = None - instance = None - try: - # try to get it from local db to check if description exists - instance = self.get_object() - except Http404: - try: - instance = self.create_or_update_model(data['traintuple'], - data['traintuple']['outModel']['hash']) - except Exception as e: - error = e - else: - # check if instance has file - if not instance.file: - try: - instance = self.create_or_update_model(data['traintuple'], - data['traintuple']['outModel']['hash']) - except Exception as e: - error = e - finally: - if error is not None: - return Response({'message': str(error)}, status=status.HTTP_400_BAD_REQUEST) - - # do not give access to local files address - if instance is not None: - serializer = self.get_serializer(instance, fields=('owner', 'pkhash', 'creation_date', 'last_modified')) - data.update(serializer.data) - else: - data = {'message': 'Fail to get 
instance'} - - return Response(data, status=status.HTTP_200_OK) - - def list(self, request, *args, **kwargs): - # can modify result by interrogating `request.version` - - data, st = queryLedger({ - 'args': '{"Args":["queryModels"]}' - }) - algoData = None - objectiveData = None - dataManagerData = None - - # init list to return - if data is None: - data = [] - l = [data] - - if st == 200: - # parse filters - query_params = request.query_params.get('search', None) - - if query_params is not None: - try: - filters = get_filters(query_params) - except Exception as exc: - return Response( - {'message': f'Malformed search filters {query_params}'}, - status=status.HTTP_400_BAD_REQUEST) - else: - # filtering, reinit l to empty array - l = [] - for idx, filter in enumerate(filters): - # init each list iteration to data - if data is None: - data = [] - l.append(data) - for k, subfilters in filter.items(): - if k == 'model': # filter by own key - for key, val in subfilters.items(): - l[idx] = [x for x in l[idx] if x['traintuple']['outModel'] is not None and x['traintuple']['outModel']['hash'] in val] - elif k == 'algo': # select model used by these algo - if not algoData: - # TODO find a way to put this call in cache - algoData, st = queryLedger({ - 'args': '{"Args":["queryAlgos"]}' - }) - if st != status.HTTP_200_OK: - return Response(algoData, status=st) - - if algoData is None: - algoData = [] - for key, val in subfilters.items(): - filteredData = [x for x in algoData if x[key] in val] - algoHashes = [x['key'] for x in filteredData] - l[idx] = [x for x in l[idx] if x['traintuple']['algo']['hash'] in algoHashes] - elif k == 'dataset': # select model which trainData.openerHash is - if not dataManagerData: - # TODO find a way to put this call in cache - dataManagerData, st = queryLedger({ - 'args': '{"Args":["queryDataManagers"]}' - }) - if st != status.HTTP_200_OK: - return Response(dataManagerData, status=st) - - if dataManagerData is None: - dataManagerData = [] - for key, val in subfilters.items(): - filteredData = [x for x in dataManagerData if x[key] in val] - datamanagerHashes = [x['key'] for x in filteredData] - l[idx] = [x for x in l[idx] if x['traintuple']['dataset']['openerHash'] in datamanagerHashes] - elif k == 'objective': # select objective used by these datamanagers - if not objectiveData: - # TODO find a way to put this call in cache - objectiveData, st = queryLedger({ - 'args': '{"Args":["queryObjectives"]}' - }) - if st != status.HTTP_200_OK: - return Response(objectiveData, status=st) - - if objectiveData is None: - objectiveData = [] - for key, val in subfilters.items(): - if key == 'metrics': # specific to nested metrics - filteredData = [x for x in objectiveData if x[key]['name'] in val] - else: - filteredData = [x for x in objectiveData if x[key] in val] - objectiveKeys = [x['key'] for x in filteredData] - l[idx] = [x for x in l[idx] if x['traintuple']['objective']['hash'] in objectiveKeys] - - return Response(l, status=st) - - @action(detail=True) - def file(self, request, *args, **kwargs): - object = self.get_object() - - # TODO query model permissions - - data = getattr(object, 'file') - return CustomFileResponse(open(data.path, 'rb'), as_attachment=True, filename=os.path.basename(data.path)) - - @action(detail=True) - def details(self, request, *args, **kwargs): - lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - pk = self.kwargs[lookup_url_kwarg] - - data, st = queryLedger({ - 'args': f'{{"Args":["queryModelDetails", "{pk}"]}}' - }) - - return 
Response(data, st) diff --git a/substrabac/substrapp/views/objective.py b/substrabac/substrapp/views/objective.py deleted file mode 100644 index 345267fc9..000000000 --- a/substrabac/substrapp/views/objective.py +++ /dev/null @@ -1,384 +0,0 @@ -import docker -import logging -import os -import re -import shutil -import tempfile -import uuid - -from urllib.parse import unquote - -import requests -from django.conf import settings -from django.db import IntegrityError -from django.http import Http404 -from django.urls import reverse -from docker.errors import ContainerError -from rest_framework import status, mixins -from rest_framework.decorators import action -from rest_framework.exceptions import ValidationError -from rest_framework.response import Response -from rest_framework.viewsets import GenericViewSet - -from substrabac.celery import app - -from substrapp.models import Objective -from substrapp.serializers import ObjectiveSerializer, LedgerObjectiveSerializer - - -from substrapp.utils import queryLedger, get_hash, get_computed_hash -from substrapp.tasks import build_subtuple_folders, remove_subtuple_materials -from substrapp.views.utils import get_filters, getObjectFromLedger, ComputeHashMixin, ManageFileMixin, JsonException, find_primary_key_error - - -@app.task(bind=True, ignore_result=False) -def compute_dryrun(self, metrics_path, test_data_manager_key, pkhash): - - dryrun_uuid = f'{pkhash}_{uuid.uuid4().hex}' - - subtuple_directory = build_subtuple_folders({'key': dryrun_uuid}) - - metrics_path_dst = os.path.join(subtuple_directory, 'metrics/metrics.py') - if not os.path.exists(metrics_path_dst): - shutil.copy2(metrics_path, os.path.join(subtuple_directory, 'metrics/metrics.py')) - os.remove(metrics_path) - - if not test_data_manager_key: - raise Exception('Cannot do a objective dryrun without a data manager key.') - - datamanager = getObjectFromLedger(test_data_manager_key, 'queryDataManager') - opener_content, opener_computed_hash = get_computed_hash(datamanager['opener']['storageAddress']) - with open(os.path.join(subtuple_directory, 'opener/opener.py'), 'wb') as opener_file: - opener_file.write(opener_content) - - # Launch verification - client = docker.from_env() - pred_path = os.path.join(subtuple_directory, 'pred') - opener_file = os.path.join(subtuple_directory, 'opener/opener.py') - metrics_file = os.path.join(subtuple_directory, 'metrics/metrics.py') - metrics_path = os.path.join(getattr(settings, 'PROJECT_ROOT'), 'fake_metrics') # base metrics comes with substrabac - - metrics_docker = 'metrics_dry_run' # tag must be lowercase for docker - metrics_docker_name = f'{metrics_docker}_{dryrun_uuid}' - volumes = {pred_path: {'bind': '/sandbox/pred', 'mode': 'rw'}, - metrics_file: {'bind': '/sandbox/metrics/__init__.py', 'mode': 'ro'}, - opener_file: {'bind': '/sandbox/opener/__init__.py', 'mode': 'ro'}} - - client.images.build(path=metrics_path, - tag=metrics_docker, - rm=False) - - job_args = {'image': metrics_docker, - 'name': metrics_docker_name, - 'cpuset_cpus': '0-0', - 'mem_limit': '1G', - 'command': None, - 'volumes': volumes, - 'shm_size': '8G', - 'labels': ['dryrun'], - 'detach': False, - 'auto_remove': False, - 'remove': False} - - try: - client.containers.run(**job_args) - if not os.path.exists(os.path.join(pred_path, 'perf.json')): - raise Exception('Perf file not found') - - except ContainerError as e: - raise Exception(e.stderr) - - finally: - try: - container = client.containers.get(metrics_docker_name) - container.remove(force=True) - except BaseException as 
e: - logging.error(e, exc_info=True) - remove_subtuple_materials(subtuple_directory) - - -class ObjectiveViewSet(mixins.CreateModelMixin, - mixins.ListModelMixin, - mixins.RetrieveModelMixin, - ComputeHashMixin, - ManageFileMixin, - GenericViewSet): - queryset = Objective.objects.all() - serializer_class = ObjectiveSerializer - ledger_query_call = 'queryObjective' - # permission_classes = (permissions.IsAuthenticated,) - - def perform_create(self, serializer): - return serializer.save() - - def create(self, request, *args, **kwargs): - """ - Create a new Objective \n - TODO add info about what has to be posted\n - - Example with curl (on localhost): \n - curl -u username:password -H "Content-Type: application/json"\ - -X POST\ - -d '{"name": "tough objective", "permissions": "all", "metrics_name": 'accuracy', "test_data": - ["data_5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b379", - "data_5c1d9cd1c2c1082dde0921b56d11030c81f62fbb51932758b58ac2569dd0b389"],\ - "files": {"description.md": '#My tough objective',\ - 'metrics.py': 'def AUC_score(y_true, y_pred):\n\treturn 1'}}'\ - http://127.0.0.1:8000/substrapp/objective/ \n - Use double quotes for the json, simple quotes don't work.\n - - Example with the python package requests (on localhost): \n - requests.post('http://127.0.0.1:8000/objective/', - #auth=('username', 'password'), - data={'name': 'MSI classification', 'permissions': 'all', 'metrics_name': 'accuracy', 'test_data_sample_keys': ['da1bb7c31f62244c0f3a761cc168804227115793d01c270021fe3f7935482dcc']}, - files={'description': open('description.md', 'rb'), 'metrics': open('metrics.py', 'rb')}, - headers={'Accept': 'application/json;version=0.0'}) \n - --- - response_serializer: ObjectiveSerializer - """ - - data = request.data - - dryrun = data.get('dryrun', False) - - description = data.get('description') - test_data_manager_key = request.data.get('test_data_manager_key', request.POST.get('test_data_manager_key', '')) - - try: - test_data_sample_keys = request.data.getlist('test_data_sample_keys', []) - except: - test_data_sample_keys = request.data.get('test_data_sample_keys', request.POST.getlist('test_data_sample_keys', [])) - - metrics = data.get('metrics') - - pkhash = get_hash(description) - serializer = self.get_serializer(data={'pkhash': pkhash, - 'metrics': metrics, - 'description': description}) - - try: - serializer.is_valid(raise_exception=True) - except ValidationError as e: - st = status.HTTP_400_BAD_REQUEST - if find_primary_key_error(e): - st = status.HTTP_409_CONFLICT - return Response({'message': e.args, 'pkhash': pkhash}, status=st) - - if dryrun: - try: - metrics_path = os.path.join(getattr(settings, 'DRYRUN_ROOT'), f'metrics_{pkhash}.py') - with open(metrics_path, 'wb') as metrics_file: - metrics_file.write(metrics.open().read()) - - task = compute_dryrun.apply_async((metrics_path, test_data_manager_key, pkhash), queue=f"{settings.LEDGER['name']}.dryrunner") - except Exception as e: - return Response({'message': f'Could not launch objective creation with dry-run on this instance: {str(e)}'}, - status=status.HTTP_400_BAD_REQUEST) - - current_site = getattr(settings, "DEFAULT_DOMAIN") - task_route = f'{current_site}{reverse("substrapp:task-detail", args=[task.id])}' - msg = f'Your dry-run has been taken in account. 
You can follow the task execution on {task_route}' - - return Response({'id': task.id, 'message': msg}, status=status.HTTP_202_ACCEPTED) - - # create on db - try: - instance = self.perform_create(serializer) - except IntegrityError as exc: - try: - pkhash = re.search(r'\(pkhash\)=\((\w+)\)', exc.args[0]).group(1) - except BaseException: - pkhash = '' - finally: - return Response({'message': 'A objective with this description file already exists.', 'pkhash': pkhash}, - status=status.HTTP_409_CONFLICT) - except Exception as exc: - return Response({'message': exc.args}, - status=status.HTTP_400_BAD_REQUEST) - - # init ledger serializer - ledger_serializer = LedgerObjectiveSerializer(data={'test_data_sample_keys': test_data_sample_keys, - 'test_data_manager_key': test_data_manager_key, - 'name': data.get('name'), - 'permissions': data.get('permissions'), - 'metrics_name': data.get('metrics_name'), - 'instance': instance}, - context={'request': request}) - - if not ledger_serializer.is_valid(): - # delete instance - instance.delete() - raise ValidationError(ledger_serializer.errors) - - # create on ledger - data, st = ledger_serializer.create(ledger_serializer.validated_data) - - if st not in (status.HTTP_201_CREATED, status.HTTP_202_ACCEPTED, status.HTTP_408_REQUEST_TIMEOUT): - return Response(data, status=st) - - headers = self.get_success_headers(serializer.data) - d = dict(serializer.data) - d.update(data) - return Response(d, status=st, headers=headers) - - def create_or_update_objective(self, objective, pk): - # get objective description from remote node - url = objective['description']['storageAddress'] - try: - r = requests.get(url, headers={'Accept': 'application/json;version=0.0'}) # TODO pass cert - except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): - raise Exception(f'Failed to fetch {url}') - if r.status_code != status.HTTP_200_OK: - raise Exception(f'end to end node report {r.text}') - - try: - computed_hash = self.compute_hash(r.content) - except Exception: - raise Exception('Failed to fetch description file') - - if computed_hash != pk: - msg = 'computed hash is not the same as the hosted file. 
Please investigate for default of synchronization, corruption, or hacked' - raise Exception(msg) - - f = tempfile.TemporaryFile() - f.write(r.content) - - # save/update objective in local db for later use - instance, created = Objective.objects.update_or_create(pkhash=pk, validated=True) - instance.description.save('description.md', f) - - return instance - - def retrieve(self, request, *args, **kwargs): - lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - pk = self.kwargs[lookup_url_kwarg] - - if len(pk) != 64: - return Response({'message': f'Wrong pk {pk}'}, status.HTTP_400_BAD_REQUEST) - - try: - int(pk, 16) # test if pk is correct (hexadecimal) - except ValueError: - return Response({'message': f'Wrong pk {pk}'}, status.HTTP_400_BAD_REQUEST) - - # get instance from remote node - try: - data = getObjectFromLedger(pk, self.ledger_query_call) - except JsonException as e: - return Response(e.msg, status=status.HTTP_400_BAD_REQUEST) - except Http404: - return Response(f'No element with key {pk}', status=status.HTTP_404_NOT_FOUND) - # try to get it from local db to check if description exists - try: - instance = self.get_object() - except Http404: - instance = None - - if not instance or not instance.description: - try: - instance = self.create_or_update_objective(data, pk) - except Exception as e: - return Response({'message': str(e)}, status=status.HTTP_400_BAD_REQUEST) - - # do not give access to local files address - serializer = self.get_serializer( - instance, fields=('owner', 'pkhash', 'creation_date', 'last_modified')) - data.update(serializer.data) - return Response(data, status=status.HTTP_200_OK) - - def list(self, request, *args, **kwargs): - # can modify result by interrogating `request.version` - - data, st = queryLedger({ - 'args': '{"Args":["queryObjectives"]}' - }) - - data = [] if data is None else data - objectives = [data] - - if st != status.HTTP_200_OK: - return Response(objectives, status=st) - - dataManagerData = None - algoData = None - modelData = None - - # parse filters - query_params = request.query_params.get('search', None) - if query_params is None: - return Response(objectives, status=st) - - try: - filters = get_filters(query_params) - except Exception: - return Response( - {'message': f'Malformed search filters {query_params}'}, - status=status.HTTP_400_BAD_REQUEST) - - # filtering - objectives = [] - for idx, filter in enumerate(filters): - # init each list iteration to data - objectives.append(data) - - for k, subfilters in filter.items(): - if k == 'objective': # filter by own key - for key, val in subfilters.items(): - if key == 'metrics': # specific to nested metrics - objectives[idx] = [x for x in objectives[idx] if x[key]['name'] in val] - else: - objectives[idx] = [x for x in objectives[idx] if x[key] in val] - - elif k == 'dataset': # select objective used by these datamanagers - if not dataManagerData: - # TODO find a way to put this call in cache - dataManagerData, st = queryLedger({ - 'args': '{"Args":["queryDataManagers"]}' - }) - if st != status.HTTP_200_OK: - return Response(dataManagerData, status=st) - if dataManagerData is None: - dataManagerData = [] - - for key, val in subfilters.items(): - filteredData = [x for x in dataManagerData if x[key] in val] - dataManagerKeys = [x['key'] for x in filteredData] - objectiveKeys = [x['objectiveKey'] for x in filteredData] - objectives[idx] = [x for x in objectives[idx] if x['key'] in objectiveKeys or - (x['testDataset'] and x['testDataset']['dataManagerKey'] in dataManagerKeys)] - - 
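The 'search' query parameter consumed by list() above follows a small ad-hoc grammar implemented by get_filters (deleted further below in substrapp/views/utils.py): '-OR-' separates alternative filter groups, ',' separates subfilters within a group, and each subfilter is a URL-encoded parent:field:value triple. A hedged sketch of a client call, with the endpoint path taken from the requests example in the docstring above and purely hypothetical asset names and values:

    import requests
    from urllib.parse import quote

    # Hypothetical query: objectives named "MSI classification" whose metrics are
    # named "accuracy", OR objectives attached to a dataset named "liver slides".
    search = ('objective:name:' + quote('MSI classification')
              + ',objective:metrics:accuracy'
              + '-OR-dataset:name:' + quote('liver slides'))

    r = requests.get('http://127.0.0.1:8000/objective/',
                     params={'search': search},
                     headers={'Accept': 'application/json;version=0.0'})

    # get_filters(search) yields one dict per -OR- group, which the filtering
    # loop above then applies to the data returned by the ledger:
    # [{'objective': {'name': ['MSI classification'], 'metrics': ['accuracy']}},
    #  {'dataset': {'name': ['liver slides']}}]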
elif k == 'model': # select objectives used by outModel hash - if not modelData: - # TODO find a way to put this call in cache - modelData, st = queryLedger({ - 'args': '{"Args":["queryTraintuples"]}' - }) - if st != status.HTTP_200_OK: - return Response(modelData, status=st) - if modelData is None: - modelData = [] - - for key, val in subfilters.items(): - filteredData = [x for x in modelData if x['outModel'] is not None and x['outModel'][key] in val] - objectiveKeys = [x['objective']['hash'] for x in filteredData] - objectives[idx] = [x for x in objectives[idx] if x['key'] in objectiveKeys] - - return Response(objectives, status=st) - - @action(detail=True) - def description(self, request, *args, **kwargs): - return self.manage_file('description') - - @action(detail=True) - def metrics(self, request, *args, **kwargs): - return self.manage_file('metrics') - - @action(detail=True) - def data(self, request, *args, **kwargs): - instance = self.get_object() - - # TODO fetch list of data from ledger - # query list of related algos and models from ledger - - # return success and model - - serializer = self.get_serializer(instance) - return Response(serializer.data) diff --git a/substrabac/substrapp/views/testtuple.py b/substrabac/substrapp/views/testtuple.py deleted file mode 100644 index ae25f0f0f..000000000 --- a/substrabac/substrapp/views/testtuple.py +++ /dev/null @@ -1,143 +0,0 @@ -import json - -from django.http import Http404 -from rest_framework import mixins, status -from rest_framework.response import Response -from rest_framework.viewsets import GenericViewSet - -from substrapp.serializers import LedgerTestTupleSerializer -from substrapp.utils import queryLedger -from substrapp.views.utils import getObjectFromLedger, JsonException - - -class TestTupleViewSet(mixins.CreateModelMixin, - mixins.RetrieveModelMixin, - mixins.ListModelMixin, - GenericViewSet): - serializer_class = LedgerTestTupleSerializer - - def get_queryset(self): - queryset = [] - return queryset - - def perform_create(self, serializer): - return serializer.save() - - def create(self, request, *args, **kwargs): - # TODO update - ''' - curl -H "Accept: text/html;version=0.0, */*;version=0.0" - -d "algo_key=da58a7a29b549f2fe5f009fb51cce6b28ca184ec641a0c1db075729bb266549b&model_key=10060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568&train_data_keys[]=62fb3263208d62c7235a046ee1d80e25512fe782254b730a9e566276b8c0ef3a&train_data[]=42303efa663015e729159833a12ffb510ff92a6e386b8152f90f6fb14ddc94c9" - -X POST http://localhost:8001/traintuple/ - - or - - curl -H "Accept: text/html;version=0.0, */*;version=0.0" - -H "Content-Type: application/json" - -d '{"algo_key":"da58a7a29b549f2fe5f009fb51cce6b28ca184ec641a0c1db075729bb266549b","model_key":"10060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568","train_data_keys":["62fb3263208d62c7235a046ee1d80e25512fe782254b730a9e566276b8c0ef3a","42303efa663015e729159833a12ffb510ff92a6e386b8152f90f6fb14ddc94c9"]}' - -X POST http://localhost:8001/traintuple/?format=json - - :param request: - :return: - ''' - - traintuple_key = request.data.get('traintuple_key', request.POST.get('traintuple_key', None)) - data_manager_key = request.data.get('data_manager_key', request.POST.get('data_manager_key', '')) - tag = request.data.get('tag', request.POST.get('tag', '')) - - try: - test_data_sample_keys = request.data.getlist('test_data_sample_keys', []) - except: - test_data_sample_keys = request.data.get('test_data_sample_keys', request.POST.getlist('test_data_sample_keys', [])) - 
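The try/except around getlist just above (repeated in the objective and traintuple views) exists because DRF's request.data is a QueryDict exposing getlist() when the body is form-encoded or multipart, but a plain dict when the body is JSON, in which case the list arrives directly under the key and getlist() raises AttributeError. A hedged sketch of a helper that could factor the pattern out (the get_list_param name is hypothetical, not something defined in this codebase):

    def get_list_param(request, key):
        """Return a list parameter from a form-encoded or a JSON request body."""
        data = request.data
        if hasattr(data, 'getlist'):    # QueryDict: form / multipart payloads
            return data.getlist(key, [])
        return data.get(key, [])        # plain dict: JSON payloads already hold a list

Each create() could then read test_data_sample_keys = get_list_param(request, 'test_data_sample_keys') instead of catching the exception inline.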
- data = { - 'traintuple_key': traintuple_key, - 'data_manager_key': data_manager_key, - 'test_data_sample_keys': test_data_sample_keys, # list of test data keys - 'tag': tag - } - - # init ledger serializer - serializer = self.get_serializer(data=data) - serializer.is_valid(raise_exception=True) - - # Get testtuple pkhash of the proposal with a queryLedger in case of 408 timeout - args = serializer.get_args(serializer.validated_data) - data, st = queryLedger({'args': '{"Args":["createTesttuple", ' + args + ']}'}) - if st == status.HTTP_200_OK: - pkhash = data.get('key', data.get('keys')) - else: - # If queryLedger fails, invoke will fail too so we handle the issue right now - try: - data['message'] = data['message'].split('Error')[-1] - msg = json.loads(data['message'].split('payload:')[-1].strip().strip('"').encode('utf-8').decode('unicode_escape')) - pkhash = msg['error'].replace('(', '').replace(')', '').split('tkey: ')[-1].strip() - - if len(pkhash) != 64: - raise Exception('bad pkhash') - else: - st = status.HTTP_409_CONFLICT - - return Response({'message': data['message'].split('payload')[0], - 'pkhash': pkhash}, status=st) - except: - return Response(data, status=st) - - # create on ledger - data, st = serializer.create(serializer.validated_data) - - if st == status.HTTP_408_REQUEST_TIMEOUT: - return Response({'message': data['message'], - 'pkhash': pkhash}, status=st) - - if st not in (status.HTTP_201_CREATED, status.HTTP_202_ACCEPTED): - try: - data['message'] = data['message'].split('Error')[-1] - msg = json.loads(data['message'].split('payload:')[-1].strip().strip('"').encode('utf-8').decode('unicode_escape')) - pkhash = msg['error'].replace('(', '').replace(')', '').split('tkey: ')[-1].strip() - - if len(pkhash) != 64: - raise Exception('bad pkhash') - else: - st = status.HTTP_409_CONFLICT - - return Response({'message': data['message'].split('payload')[0], - 'pkhash': pkhash}, status=st) - except: - return Response(data, status=st) - - headers = self.get_success_headers(serializer.data) - return Response(data, status=st, headers=headers) - - def list(self, request, *args, **kwargs): - # can modify result by interrogating `request.version` - - data, st = queryLedger({ - 'args': '{"Args":["queryTesttuples"]}' - }) - - data = data if data else [] - - return Response(data, status=st) - - def retrieve(self, request, *args, **kwargs): - lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - pk = self.kwargs[lookup_url_kwarg] - - if len(pk) != 64: - return Response({'message': f'Wrong pk {pk}'}, status.HTTP_400_BAD_REQUEST) - - try: - int(pk, 16) # test if pk is correct (hexadecimal) - except: - return Response({'message': f'Wrong pk {pk}'}, status.HTTP_400_BAD_REQUEST) - else: - # get instance from remote node - try: - data = getObjectFromLedger(pk, 'queryTesttuple') - except JsonException as e: - return Response(e.msg, status=status.HTTP_400_BAD_REQUEST) - except Http404: - return Response(f'No element with key {pk}', status=status.HTTP_404_NOT_FOUND) - else: - return Response(data, status=status.HTTP_200_OK) diff --git a/substrabac/substrapp/views/traintuple.py b/substrabac/substrapp/views/traintuple.py deleted file mode 100644 index 2ad5d213d..000000000 --- a/substrabac/substrapp/views/traintuple.py +++ /dev/null @@ -1,170 +0,0 @@ -import json - -from django.http import Http404 -from rest_framework import mixins, status -from rest_framework.response import Response -from rest_framework.viewsets import GenericViewSet - -from substrapp.serializers import 
LedgerTrainTupleSerializer -from substrapp.utils import queryLedger -from substrapp.views.utils import JsonException - - -class TrainTupleViewSet(mixins.CreateModelMixin, - mixins.RetrieveModelMixin, - mixins.ListModelMixin, - GenericViewSet): - serializer_class = LedgerTrainTupleSerializer - - def get_queryset(self): - queryset = [] - return queryset - - def perform_create(self, serializer): - return serializer.save() - - def create(self, request, *args, **kwargs): - # TODO update - ''' - curl -H "Accept: text/html;version=0.0, */*;version=0.0" - -d "algo_key=da58a7a29b549f2fe5f009fb51cce6b28ca184ec641a0c1db075729bb266549b&model_key=10060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568&train_data_sample_keys[]=62fb3263208d62c7235a046ee1d80e25512fe782254b730a9e566276b8c0ef3a&train_data[]=42303efa663015e729159833a12ffb510ff92a6e386b8152f90f6fb14ddc94c9" - -X POST http://localhost:8001/traintuple/ - - or - - curl -H "Accept: text/html;version=0.0, */*;version=0.0" - -H "Content-Type: application/json" - -d '{"algo_key":"da58a7a29b549f2fe5f009fb51cce6b28ca184ec641a0c1db075729bb266549b","model_key":"10060f1d9e450d98bb5892190860eee8dd48594f00e0e1c9374a27c5acdba568","train_data_sample_keys":["62fb3263208d62c7235a046ee1d80e25512fe782254b730a9e566276b8c0ef3a","42303efa663015e729159833a12ffb510ff92a6e386b8152f90f6fb14ddc94c9"]}' - -X POST http://localhost:8001/traintuple/?format=json - - :param request: - :return: - ''' - - algo_key = request.data.get('algo_key', request.POST.get('algo_key', None)) - data_manager_key = request.data.get('data_manager_key', request.POST.get('data_manager_key', None)) - objective_key = request.data.get('objective_key', request.POST.get('objective_key', None)) - rank = request.data.get('rank', request.POST.get('rank', None)) - FLtask_key = request.data.get('FLtask_key', request.POST.get('FLtask_key', '')) - tag = request.data.get('tag', request.POST.get('tag', '')) - - try: - in_models_keys = request.data.getlist('in_models_keys', []) - except: - in_models_keys = request.data.get('in_models_keys', request.POST.getlist('in_models_keys', [])) - - try: - train_data_sample_keys = request.data.getlist('train_data_sample_keys', []) - except: - train_data_sample_keys = request.data.get('train_data_sample_keys', request.POST.getlist('train_data_sample_keys', [])) - - data = { - 'algo_key': algo_key, - 'data_manager_key': data_manager_key, - 'objective_key': objective_key, - 'rank': rank, - 'FLtask_key': FLtask_key, - 'in_models_keys': in_models_keys, - 'train_data_sample_keys': train_data_sample_keys, # list of train data keys (which are stored in the train worker node) - 'tag': tag - } - - # init ledger serializer - serializer = self.get_serializer(data=data) - serializer.is_valid(raise_exception=True) - - # Get traintuple pkhash of the proposal with a queryLedger in case of 408 timeout - args = serializer.get_args(serializer.validated_data) - data, st = queryLedger({'args': '{"Args":["createTraintuple", ' + args + ']}'}) - if st == status.HTTP_200_OK: - pkhash = data.get('key', data.get('keys')) - else: - # If queryLedger fails, invoke will fail too so we handle the issue right now - try: - data['message'] = data['message'].split('Error')[-1] - msg = json.loads(data['message'].split('payload:')[-1].strip().strip('"').encode('utf-8').decode('unicode_escape')) - pkhash = msg['error'].replace('(', '').replace(')', '').split('tkey: ')[-1].strip() - - if len(pkhash) != 64: - raise Exception('bad pkhash') - else: - st = status.HTTP_409_CONFLICT - - return 
Response({'message': data['message'].split('payload')[0], - 'pkhash': pkhash}, status=st) - except: - return Response(data, status=st) - - # create on ledger - data, st = serializer.create(serializer.validated_data) - - if st == status.HTTP_408_REQUEST_TIMEOUT: - return Response({'message': data['message'], - 'pkhash': pkhash}, status=st) - - if st not in (status.HTTP_201_CREATED, status.HTTP_202_ACCEPTED): - try: - data['message'] = data['message'].split('Error')[-1] - msg = json.loads(data['message'].split('payload:')[-1].strip().strip('"').encode('utf-8').decode('unicode_escape')) - pkhash = msg['error'].replace('(', '').replace(')', '').split('tkey: ')[-1].strip() - - if len(pkhash) != 64: - raise Exception('bad pkhash') - else: - st = status.HTTP_409_CONFLICT - - return Response({'message': data['message'].split('payload')[0], - 'pkhash': pkhash}, status=st) - except: - return Response(data, status=st) - - headers = self.get_success_headers(serializer.data) - return Response(data, status=st, headers=headers) - - def list(self, request, *args, **kwargs): - data, st = queryLedger({ - 'args': '{"Args":["queryTraintuples"]}' - }) - - data = data if data else [] - - return Response(data, status=st) - - def getObjectFromLedger(self, pk): - # get instance from remote node - data, st = queryLedger({ - 'args': f'{{"Args":["queryTraintuple","{pk}"]}}' - }) - - if st == status.HTTP_404_NOT_FOUND: - raise Http404('Not found') - - if st != status.HTTP_200_OK: - raise JsonException(data) - - if 'permissions' not in data or data['permissions'] == 'all': - return data - else: - raise Exception('Not Allowed') - - def retrieve(self, request, *args, **kwargs): - lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - pk = self.kwargs[lookup_url_kwarg] - - if len(pk) != 64: - return Response({'message': f'Wrong pk {pk}'}, status.HTTP_400_BAD_REQUEST) - - try: - int(pk, 16) # test if pk is correct (hexadecimal) - except: - return Response({'message': f'Wrong pk {pk}'}, status.HTTP_400_BAD_REQUEST) - else: - # get instance from remote node - try: - data = self.getObjectFromLedger(pk) - except JsonException as e: - return Response(e.msg, status=status.HTTP_400_BAD_REQUEST) - except Http404: - return Response(f'No element with key {pk}', status=status.HTTP_404_NOT_FOUND) - else: - return Response(data, status=status.HTTP_200_OK) diff --git a/substrabac/substrapp/views/utils.py b/substrabac/substrapp/views/utils.py deleted file mode 100644 index 182058283..000000000 --- a/substrabac/substrapp/views/utils.py +++ /dev/null @@ -1,144 +0,0 @@ -import hashlib -import os -from urllib.parse import unquote - -from django.http import FileResponse, Http404 -from rest_framework import status -from rest_framework.response import Response - -from substrapp.utils import queryLedger - - -class JsonException(Exception): - def __init__(self, msg): - self.msg = msg - super(JsonException, self).__init__() - - -def get_filters(query_params): - filters = [] - groups = query_params.split('-OR-') - for idx, group in enumerate(groups): - - # init - filters.append({}) - - # get number of subfilters and decode them - subfilters = [unquote(x) for x in group.split(',')] - - for subfilter in subfilters: - el = subfilter.split(':') - - # get parent - parent = el[0] - subparent = el[1] - value = el[2] - - filter = { - subparent: [unquote(value)] - } - - if not len(filters[idx]): # create and add it - filters[idx] = { - parent: filter - } - else: # add it - if parent in filters[idx]: # add - if el[1] in filters[idx][parent]: # 
concat in subparent - filters[idx][parent][subparent].extend([value]) - else: # add new subparent - filters[idx][parent].update(filter) - else: # create - filters[idx].update({parent: filter}) - - return filters - - -def getObjectFromLedger(pk, query): - # get instance from remote node - data, st = queryLedger({ - 'args': f'{{"Args":["{query}","{pk}"]}}' - }) - - if st == status.HTTP_404_NOT_FOUND: - raise Http404('Not found') - - if st != status.HTTP_200_OK: - raise JsonException(data) - - if 'permissions' not in data or data['permissions'] == 'all': - return data - else: - raise Exception('Not Allowed') - - -class ComputeHashMixin(object): - def compute_hash(self, file, key=None): - - sha256_hash = hashlib.sha256() - if isinstance(file, str): - file = file.encode() - - if key is not None and isinstance(key, str): - file += key.encode() - - sha256_hash.update(file) - computedHash = sha256_hash.hexdigest() - - return computedHash - - -class CustomFileResponse(FileResponse): - def set_headers(self, filelike): - super(CustomFileResponse, self).set_headers(filelike) - - self['Access-Control-Expose-Headers'] = 'Content-Disposition' - - -class ManageFileMixin(object): - def manage_file(self, field): - lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - pk = self.kwargs[lookup_url_kwarg] - - # TODO get cert for permissions check - - try: - getObjectFromLedger(pk, self.ledger_query_call) - except Exception as e: - return Response(e, status=status.HTTP_400_BAD_REQUEST) - except Http404: - return Response(f'No element with key {pk}', status=status.HTTP_404_NOT_FOUND) - else: - object = self.get_object() - - data = getattr(object, field) - return CustomFileResponse(open(data.path, 'rb'), as_attachment=True, filename=os.path.basename(data.path)) - - -def find_primary_key_error(validation_error, key_name='pkhash'): - detail = validation_error.detail - - def find_unique_error(detail_dict): - for key, errors in detail_dict.items(): - if key != key_name: - continue - for error in errors: - if error.code == 'unique': - return error - - return None - - # according to the rest_framework documentation, - # validation_error.detail could be either a dict, a list or a nested - # data structure - - if isinstance(detail, dict): - return find_unique_error(detail) - elif isinstance(detail, list): - for sub_detail in detail: - if isinstance(sub_detail, dict): - unique_error = find_unique_error(sub_detail) - if unique_error is not None: - return unique_error - - return None
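Asset primary keys throughout these views are content hashes: an objective's pkhash is the SHA-256 hex digest of its description file (hence the 64-character hexadecimal checks in the retrieve() methods), and create_or_update_objective recomputes that digest on the file fetched from a remote node before trusting it. A minimal standalone sketch mirroring ComputeHashMixin.compute_hash above:

    import hashlib

    def compute_hash(content, key=None):
        """SHA-256 hex digest of a file's bytes, optionally salted with a key."""
        if isinstance(content, str):
            content = content.encode()
        if key is not None and isinstance(key, str):
            content += key.encode()
        return hashlib.sha256(content).hexdigest()

    with open('description.md', 'rb') as f:
        pkhash = compute_hash(f.read())   # 64-character hexadecimal string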