From a6192e930cafcf828655b9370268d5de05b8084b Mon Sep 17 00:00:00 2001 From: alexioannides Date: Tue, 20 Jul 2021 22:59:59 +0100 Subject: [PATCH] Finalise stages and tests --- notebooks/requirements_nb.txt | 2 +- notebooks/time_to_dispatch_model.ipynb | 2 +- pipeline/serve_model.py | 36 ++++++++++++++++++--- pipeline/train_model.py | 17 +++++----- pipeline/utils.py | 18 ----------- requirements_pipe.txt | 11 +++---- tests/resources/model.pkl | Bin 0 -> 24069 bytes tests/test_serve_model.py | 43 +++++++++++++++++++++++-- tests/test_train_model.py | 25 ++++++++------ 9 files changed, 103 insertions(+), 51 deletions(-) delete mode 100644 pipeline/utils.py create mode 100644 tests/resources/model.pkl diff --git a/notebooks/requirements_nb.txt b/notebooks/requirements_nb.txt index c01baef..3117a56 100644 --- a/notebooks/requirements_nb.txt +++ b/notebooks/requirements_nb.txt @@ -1,7 +1,7 @@ jupyterlab==3.0.16 seaborn==0.11.1 numpy==1.21.0 -pandas==1.2.5 +pandas==1.3.0 scikit-learn==0.24.2 boto3==1.17.101 joblib==1.0.1 diff --git a/notebooks/time_to_dispatch_model.ipynb b/notebooks/time_to_dispatch_model.ipynb index f227e2e..32cc41b 100644 --- a/notebooks/time_to_dispatch_model.ipynb +++ b/notebooks/time_to_dispatch_model.ipynb @@ -344,7 +344,7 @@ "def preprocess(df: pd.DataFrame) -> np.ndarray:\n", " df_processed = df.copy()\n", " category_map = {\"SKU001\": 0, \"SKU002\": 1, \"SKU003\": 2, \"SKU004\": 3, \"SKU005\": 4}\n", - " df_processed[\"product_code\"] = df[\"product_code\"].apply(lambda e: category_map[e])\n", + " df_processed[\"product_code\"] = df[\"product_code\"].apply(lambda e: PRODUCT_CODE_MAP[e])\n", " return df_processed.values\n", "\n", "preprocess(dataset)" diff --git a/pipeline/serve_model.py b/pipeline/serve_model.py index 9bca1c1..c5ed454 100644 --- a/pipeline/serve_model.py +++ b/pipeline/serve_model.py @@ -2,18 +2,33 @@ - Get model and load into memory. - Start web API server. """ +import sys +from enum import Enum from typing import Dict, Union import uvicorn +from bodywork_pipeline_utils import aws, logging from fastapi import FastAPI, status -from pydantic import BaseModel +from numpy import array +from pydantic import BaseModel, Field + +from pipeline.train_model import PRODUCT_CODE_MAP app = FastAPI(debug=False) +log = logging.configure_logger() + + +class ProductCode(Enum): + SKU001 = "SKU001" + SKU002 = "SKU002" + SKU003 = "SKU003" + SKU004 = "SKU004" + SKU005 = "SKU005" class Data(BaseModel): - product_code: str - orders_placed: float + product_code: ProductCode + orders_placed: float = Field(..., ge=0.0) class Prediction(BaseModel): @@ -27,8 +42,21 @@ class Prediction(BaseModel): response_model=Prediction, ) def time_to_dispatch(data: Data) -> Dict[str, Union[str, float]]: - return {"est_hours_to_dispatch": 1.0, "model_version": "0.0.1"} + features = array([[data.orders_placed, PRODUCT_CODE_MAP[data.product_code.value]]]) + prediction = wrapped_model.model.predict(features).tolist()[0] + return {"est_hours_to_dispatch": prediction, "model_version": str(wrapped_model)} if __name__ == "__main__": + try: + args = sys.argv + s3_bucket = args[1] + wrapped_model = aws.get_latest_pkl_model_from_s3(s3_bucket, "models") + log.info(f"Successfully loaded model: {wrapped_model}") + except IndexError: + log.error("Invalid arguments passed to serve_model.py - expected S3_BUCKET") + sys.exit(1) + except Exception as e: + log.error(f"Could not get latest model and start web server - {e}") + sys.exit(1) uvicorn.run(app, host="0.0.0.0", workers=1) diff --git a/pipeline/train_model.py b/pipeline/train_model.py index f9f7426..3faec02 100644 --- a/pipeline/train_model.py +++ b/pipeline/train_model.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, NamedTuple, Tuple from bodywork_pipeline_utils import aws, logging -from bodywork_pipeline_utils.aws.datasets import Dataset +from bodywork_pipeline_utils.aws import Dataset from numpy import array, ndarray from pandas import DataFrame from sklearn.base import BaseEstimator @@ -15,7 +15,7 @@ from sklearn.metrics import mean_absolute_error, r2_score from sklearn.tree import DecisionTreeRegressor -CATEGORY_MAP = {"SKU001": 0, "SKU002": 1, "SKU003": 2, "SKU004": 3, "SKU005": 4} +PRODUCT_CODE_MAP = {"SKU001": 0, "SKU002": 1, "SKU003": 2, "SKU004": 3, "SKU005": 4} HYPERPARAM_GRID = { "random_state": [42], "criterion": ["mse", "mae"], @@ -115,8 +115,8 @@ def verify_trained_model_logic(model: BaseEstimator, data: FeatureAndLabels) -> issues_detected: List[str] = [] orders_placed_sensitivity_checks = [ - model.predict(array([[100, product], [110, product]])).tolist() - for product in range(len(CATEGORY_MAP)) + model.predict(array([[100, product], [150, product]])).tolist() + for product in range(len(PRODUCT_CODE_MAP)) ] if not all(e[0] < e[1] for e in orders_placed_sensitivity_checks): issues_detected.append( @@ -140,9 +140,9 @@ def verify_trained_model_logic(model: BaseEstimator, data: FeatureAndLabels) -> def preprocess(df: DataFrame) -> DataFrame: """Create features for training model.""" - df_processed = df.copy() - df_processed["product_code"] = df["product_code"].apply(lambda e: CATEGORY_MAP[e]) - return df_processed.values + processed = df.copy() + processed["product_code"] = df["product_code"].apply(lambda e: PRODUCT_CODE_MAP[e]) + return processed.values def persist_model( @@ -152,7 +152,6 @@ def persist_model( metadata = { "r_squared": metrics.r_squared, "mean_absolute_error": metrics.mean_absolute_error, - "category_map": CATEGORY_MAP } wrapped_model = aws.Model("time-to-dispatch", model, dataset, metadata) s3_location = wrapped_model.put_model_to_s3(bucket, "models") @@ -169,13 +168,13 @@ def persist_model( r2_metric_warning_threshold = float(args[3]) if r2_metric_warning_threshold <= 0 or r2_metric_warning_threshold > 1: raise ValueError() - except (ValueError, IndexError): log.error( "Invalid arguments passed to train_model.py. " "Expected S3_BUCKET R_SQUARED_ERROR_THRESHOLD R_SQUARED_WARNING_THRESHOLD, " "where all thresholds must be in the range [0, 1]." ) + sys.exit(1) try: main( diff --git a/pipeline/utils.py b/pipeline/utils.py deleted file mode 100644 index f6df40d..0000000 --- a/pipeline/utils.py +++ /dev/null @@ -1,18 +0,0 @@ -""" -Utility functions that are common between stages. -""" -import sys -import logging - - -def configure_logger() -> logging.Logger: - """Configure a logger that will write to stdout.""" - log_handler = logging.StreamHandler(sys.stdout) - log_format = logging.Formatter( - "%(asctime)s - %(levelname)s - %(module)s.%(funcName)s - %(message)s" - ) - log_handler.setFormatter(log_format) - log = logging.getLogger(__name__) - log.addHandler(log_handler) - log.setLevel(logging.INFO) - return log diff --git a/requirements_pipe.txt b/requirements_pipe.txt index 77f1030..598b11a 100644 --- a/requirements_pipe.txt +++ b/requirements_pipe.txt @@ -1,8 +1,7 @@ -numpy>=1.21.0 -pandas>=1.2.0 -scikit-learn>=0.24.0 -boto3>=1.17.0 -joblib>=1.0.0 +numpy==1.21.0 +pandas==1.2.5 +scikit-learn==0.24.2 +boto3==1.17.101 fastapi==0.65.2 uvicorn==0.14.0 -git+https://github.com/bodywork-ml/bodywork-pipeline-utils@v0.1.1 +git+https://github.com/bodywork-ml/bodywork-pipeline-utils@v0.1.4 diff --git a/tests/resources/model.pkl b/tests/resources/model.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7b8ab02daf1356135a64682efe194c00030bdb01 GIT binary patch literal 24069 zcmbVUd0b52`@b!;qLql6_OzFFaUUX-5=EAnrkWPo)RZlZCF>xJeapU+H6MH!LdcRW zWE+Z#D2k%+yESKg=Q{I!-QRrsqvt&5JkLGv^E_vH&bfCEzulT266%LiD}%V#1t+&lb4v3;l!De;gp;j%T3A<3fcv?CWNbnJ?98#3T*}L*eRT3u27)E zOHAfE@=_gR6Eo5{yqE-`K#$E!=Om`EV>vud2A9X4#GNJ-xFRmYxmAJ9OJz$H*lsRv zu8uBVjxMg@uAbTM9@)LTonkVwgaZA)9+bezpdN2-?&aabjfsinxW{m#y*MsjRMOLv z6YI{6qmpsaG3G*n7P}2NTcW~>cK7yliSzJ`QWBSwp5nwy=W?CcF-e?^3@&8=!vWlw#Eit$lyEA41UEjNn~{;4{vU&A z#H1(kxapLDP@s~WL77H9BP}VBN973xs?ppGo=~8X%$drL<)-lxghA>8-DJv08Jy&_ zByI+~O;#ABEYOk7rI5t6W*W6-PT?lTC-B&D>6{oIRSmnXaFE1_1R>CFeRv#~!^@I0!K!*W6B>eY^s*lR%pp&wep%N?9@zNS|%@pDpnG7 zkZnMlLc3NUu&EwJ_epCmWz$wiQ1+)Dq?eMJoHorVCN-VwL?tON#I$ySxb)QI=*+k{ z$^)x~wrkaqsgw|(psM-P)_s~E?%@yI6`O}q#{uA6NZF@1pcP}+WrL* zJYjT@QjSod^dK$k8zXA%2K{N z7-&shb$-1Ds7Qgemb;{ZDw69VEB_~L4UZFwDmonOr_N{LzzuOg3Tbk8IZK1VO$o~C83=$K~=Dv8Qbo&}9ZI2}1h*8fX zH!azE-33ff3q6>Q17#96pyhG5*s$&qDfO46M2E4y!cFVgY6G8tDutV>` zfxDPqM+B*p6dH{-kqxKzEE8#$kp{GUcgFgp;(z0m`cqdbK%?Ej?!?XG$!DLC(tao@ zc4MrsqU^-kE7oNsy+2tp>glx`b$1yTCA)rRC*#~s)k zD@bcveg2H~%^LarMLn;aG@q_yRU>RZg7UD0#8{}-GkA^&{1@b)$4_ijku zn1(~j^n|*GrrqVrVv^o}ofvV78WKi@D?d|so+!&|-f}LWoRrSb5+v%3^|9i*8z0hp zgy~7`YlE6Z_n#^V^ImK`y|$5bkT$uE27oEz>3&6-{9aI2`&2K$=?5vbx1_`mSV+VI zQ5xa90`K3_F$t;PNvXdi1*S={w^En@ z?uaK5rR#}FcTKBl;tS`?BNx==<0@#$->OnR{2)yeU(7Jis02{?q#&#`+6W zeTn_Ih0Z?n1*@+&YFA%*?Sp;*UCIhL`X8|I?uX?2$khY5&MKbWd7l`oj|pd9|2`}F zKe(@O|AbV*5$|oU6}F3aY|mrWN?)=10JUqNy!P>XL%vSf`t-1`SlrVO$@iD5=Lu7* z0(zvr{G@O^LAYj(JbkVRYo9-=XMnu=2yYNP?qcB_`bFV+$$OLY$cO=-vHF<&$1;5~ z)>l-E%}W6Y-w=5?p+AKF+#LLkh3!(|a8tSYVGGvJf>66g%WEGspRp`TOw(I_ky1OfiHU~8P$VBF zSI-;rFRoV6`1YBktyla(fQ~7MP^3qnceO!uVCdTuiz=~rDFn$6lUtt`>}K{3t}}{!@jtG0O0GN*xx@GvVZS zC{OZfxTUaNN>*@|^}6txl7+mKy#pd9guKVQe_K7 z;~)X?lV$UzX$XX<(&qScX0@a={%b4Z2SXx~nJAT&rF?(j3v8ycqdt&kw0vOfXZ-qx z!Kcj1FuP1b&q;3+CALndBYe~T3r;9{-V{5bl9ZmaNyL011IcG0N+Y7706X^xiC(X< zc!GxnGieHl5d{JCD)&kCRD8W9hHK3B{(Vki|FPbO-%g1s$MmKkNK@tF3SvHqupm7CA?Cm72;WS(IEMm%&&53hjy%I~&Oq|n z^7M!xU{#+Ho3hKX{U0OFLvPebRotFTdaQya&l=-iWAWfDgm1PyoDc%URol#oO8EH% zQ+)(H-Ml7dj>3K*25aon+n}hHP5XHv(DHox@shO0Xe@MgvVJ?!-L6kla}-8TEEeoj)P+ zIW``aAbd;Z;fxvy#00-3dYQN$V|_CoIVbwRP#kBa`=CYC`_#YBp_d_iLU}k@VZge5 z-0Np*J;|b#11v}UN?JZ*sD8$GyLx+jb2&E8uRtQX^7M$|04z79J3PhrjYQ@eR@{(R zm>-zzDN3_knP2}2v*#*=Z?!y}t^JJ1Z#T#53AVm7)kkFPNlS2krttXV?{y7a>2MFj zxd!1|FAt}9B=E=QIOz8NKBiD_bs(g-wXxCz~%#ivl6H z`}AK$ztw;i5A(;uW+by6QTjZu3ZeadeBX(bz(@M5R>~g=TM(azD2?!kK-3c3v2)9+ zFh6WXg4<{ch!G70zOB%1t-Ip(6zRQG5m(~-6(&FMpE@V`jeV-{ybBlm5AY5Cj`?o~ zYRYc8?HdHFt3SSW%4;U=Wy|>kVJG5uAxifHb)GfY^xP)xdTjr|)IKP`T;vyCiLFOW zdg9sZZ`U7bz~;d{2%oh7{f(1w^?Oql`di_1;k;QLIg5lX3j2ZY&ox=dt;77VAK}cG zYfn`mHopo^xuguPsJxBv2g3oxAC}FRrePSskuF7iVGXu_IEZ8pNo8fJC>U5<`uGtB zxILNtz$$%tr($4}!ryto*VS&;)V}QRb@K?qbyOCgG%Xqig!=%)>arVHe9eTD$VyQ@ z?08<`^?^7MdwgHR4Q%~8hA^E(^zVKqV!xO%QR_Rl-#m`wPe|J&OM%+QIRzwXUizZ& zcmTa=%G&GSV(S=_UHG4yN_@|~#pc0N2-j&@e9|1*#Q#gnM@(KA@Q?LyI;~Vk8q&%=;2h%5)AA9+qT-&V+lA9JKPWuE5St(0 zy0-HLHoq|0lW!Zt**ESZ7H3^RkS@x@DWc+@xN;Z$qSp$~pTzFQeKQQj3eOY7#vsww zfrVJVzJ%~yk%yB{#aX8hrY(qT#{6&@$rsAgV^MMPr-P@$D(e)VU$VP3&A(ToupbCP z*M$#eS7YaAR}sEz@^Dh|Ct*6Qe#?$83fq%EVYT<6Y2UE=nEDGr>|5`%*bVc;b%gJ> zJe*X&CK6pvzZ%m*>eI@5!41UUq~#+Q8-M4v->L3@xXn=KcMGJxYmBXd{5AJp;&Q0 zyseqrv*rzEPez>AFM0atsRYa^^l@EGxRviZ_qfhJ(ny-u zM!n$~;-Ay<5hET0B2I~y@}aL-JkRLAX3gV+y5GUZJCmOH#GXCbh2OAs>?MNqN*+$p zcpz36`Bfjn&xM)lBb59fy_r+0c>m1)c6aHa5^NrPjqtsZhm+s{(eq~Xz-@}JW6-(m z(owxSEY2!N^;F2MFB*tlrFxtW`1?jiyLdES`*!Gw!sDIaG;hXL4g7uMJA|)V9!?ky ze4Ejsrx!PqR7>Lbj$e<9!vw$EYI6Ma8cUzCH*FP{;t zCPe?9cS}}O6xK(Y!QcCUFG&8YOtbzZOJaer^L5xC^9_sZ>XA^xA7Pn9R4nkb3f(XA z%dv6G*gowKJD=|<#@ffImunXp*Lw@rKffVZjWY25By+hyxHJ|G{QOmMoVFRdBLM&Y zni1#V-e0Rv5b$>#@*RO{`2$WSp+n*HXI*)}UG%Xpbv~gc&1<8~=QhOPM$fP3N^rmb z<%qw3PePnA92;NuQQ`5xx8UAvw5Z4WizaGc2YGg3O#tE;9NBWW0b9qkki51$J?c9i zLOpI}>4!%OuVX~oSeEaQV#UvG__h<*>2JcuZAXN!lRTU(HW2f_nz{bGhuP&XH(AP?j2a8X#bI$Vxtj3zdjvfsZ-^*eT^smn>AfRN z32fceM}h`)1xYWa;yPkmr?hJkpA_%cE@6Fwr=L)~UoStEYuvpW>(|on%DTwI$)a$+ zU%~4Cxl!>rWXbW4PFFr*aRH;>t)u-X&-sq^m#zq3cSPxMwW)KPuGMejl%X?ip74Zj zNYI$BAd$+W#!-OZZk0)&6(2{ryIKnDpJC&Hv3(oAq<;#>-)r|k_)O&CjG}Os-dp6| zRHe8*e=WN)>EnCFajL)d)$@CZ^-mVU=Zq*FE;Sy~XWp?Z?+KnXKLmjp611f&NTf=l zL3DQMlKngr_qPflBJ>M>9v^u7;);%=sn?b z0XDz%MED#KrNgD-iD^gAX*B6V2b$mQzzzx8(-kCAF4X>^V3KX()o+ToPb6@yOI?TI zWNcs71-~LseBO0L_?+b7tfO%9!otEDmH*t2AZY;e@9l}%MHW>_YFJ#yh?6)tZTRB=0UyuY}sC#uls>?8ncS+>m@Pd3r$UU7xea zd7k3Gn;>qQt==7}`0qCf(Tw7<<9o67!2{v*l!uc~jia#(t z_Whj=uP2z|?*sZHd;{g-)JX)QZswps>lSRD_eb&rp%C) zNIJy-W52oZxL@6`Mhs^lsvt;ieN7}kx?sD>yUu^^&ym!}d@lfB$g2U+wl5!oYJ|bYK&DQ^6ze$vT%&M7(`H!)m zO|{g$d=$S=Ck){Wmxr^CI$yHWQ%x{wQhfYHjj@S3h4*Vloa^r|JJIha*1k~)-*~w= zlc@LFCN^icYl0Td57rQYkt2*Yt(MWI%T|x5nl1V^RIdoSU{6lfO zgj$&yi534{0g-Jy((~94Y(5)@@I}hQ$(l%w2d534;~TL!gwYQt*B@lZf5!a4s8@1v z=|qeFVg5`wJUQnzZWlIcSCqW=4W0~;ct?GD5-1+;b~M;>z*F&fmw&n>BrV4ai<5;3 zE&`QN9-gw-v}FrJYz3N0+&CUPCLuA2y3S84z3r|%#Hpcc@u(a4cnQ?)cnPVkGV0Qs z?}P&Cvbwe#+t6)%LO~bFWqaIM_FpBu*4y3^22vnsSH&f`%eg3ySRhOYNC*v5r|2Mx z>QHoWt7e?cC3A8$jVQ?>2@ydoiUv}ewiFGbc$Wpj%pmU|r63*ZB0B0~xg24FSL;1= zg7#UQq)h6Pzt$V()X~tKqoW>G9A^%t`f7LPE-(Opjn!&P`lx_U!?N(92O7woY3F%e z>T1ZxlPiCZTK0zIlXG6SXgwvnoJbwMoArp?eZ*iruda;LO8RIw4r5uNhR;M12wTK~z#bl=?!;*fjFp-M1va z%J{p(vCrhrQxD_kn|~y)@0xw%fc9xJDSK~lZO$jsz{adnZA~57P#rxfb87`DJmy|? zWk(ZPF>&w64msb*+Wgnu^zE9+q}|zbYF^ipGZ)W@KDWXo1C{s zv#ihYucXnSy5Kjv)F6G`+nryv+e0)MUN_2D2b=6P&zFI!;Q33l@mQV~6n@DaHTQ!W zw0pRa)%m_E^s8LcsO+f@n{GQg73HhI;G(pPpHqI5>*gkmOrdanJl4-Dvqb~mo^;Xg zn579OytTguduTwt`Rz*`D>{M8#v7*&JGY1FHA**H7VARM53Osd)x z*MV303s~I;>Oz6wd&Tal9iXeap<{<{x}Y<8@;$DyCLGO+9k<5O5Qd(#9)GS-4~iWw zSI*7o0KHFtpS<#32Z-FcWIWHX1DFllH_GONKKwI`we7N(J{%1R5g(sy08h9kqbw$M zhci{h!?v3DfSTFOMxvX>Fn{>w1=-V0z$mIP+Oee@92|40XU0(zsP^2dwt?sjcePVJ z$|H5b%GcKOz_m_bezfXXot6Qt3++)7)Ta{+y3@Q~-#`njj?F!6__ICioja{kZ?ZN# zIB{=eWrPlJPNv@P6x|7|roOm3W1KF;_58N7i=RHMQ@=Dod8IC8=doks>-0b=N6=yV zUOl*8VY@fnM;}b*|J>lKs|yJ?#6}~{G+=)Uv1P2T9`ru3DRuf>Lx|9d$ttcfgr^Zr zeM6QRz$ep&FEP_PgT~VfJCm*(0PA=9j{3d^@a4GX?zh7Y!BPVpTn8A#Sm(@NEYB{$ z=NjE?rt0y2#dS~*ru_KfW_+!+A>88#g6uquz~S(wzLz?7fhF4tJFB@H!5IJCly7G{ zgV$>3$YV#lLeJ?-H>C_Q2F;1PI^ELh3go?)gPU*lfH{Z0HLiKx4YvL=D7kOm4T6i$ zJznsr8}xMxJ{5Va3q-iyIC1v3G4OZPSRM834obu3-o3Zd7&iEkaXv0Am~tyIe8B)y zQ2Fw5lz}q~;(d!O3mQ!zczWur)7Q-5kmEwFOkZ=D;g`4KO@IlU9cGk&($pB1Di{4a zOyT!maio$x-2nP5{Ze}Sjvm&R-duL8Y4~Y0wbf<6}3+zMBem4&{1(SY@u9T&j!_-d`x70MTz{9IkSz>Ju z2+aBsf97*{xc%IEQu5gz)bqD|+73lk3LMq~>Aw@VeBz@jZZus9iYulQ{$|=Gd>Xv4Y)Qh8nJoHwV)m6XTuwSwq~h zGq064*Z}MO2o1ASR-kY3E3Zb`67tVeSN0Z|!@^Tn$_G3%gMlJLV!4C%j0N1 zRqqM6`AZ@?KeB~DvkQ;PANPbO^WW?2xMmL`&Q z{pYm(Jv+F6*wttuwPhJ~p36sA) z3!Jyl4CX!QM%0B+Hz;bJ(z$1A3Nyy2H~YRc1}{(6+N@?%_H?>N{T?vL8+U%1*FYKTpGV_MAMY4w-%y`ycwtxPjko!SZNuiKN?Zjm(z zcX}Ps3AKXS5x2u8#@RwajH}IUeOm}SI&9jP%APQ#Wz#WghJknco^74_&<2J!EIy#O z%o@24$v-Z#CB;5a~T$b)mE`q;xiucLQF{$me@Yb@0V-slMj zPNsSOE_Mc|bK4K+)H{RK5%S{HT~s|ejae%3u3)inwySuk3rzWSZ*-sDPEfneYy8aN z4iI3O^{>TmConA5Y~WY80Bgr|6RY_yz;~@*9QVi-R`%I-IE?28`jrbGZjE(?Q3KoW zb@q0FmhLCY3$MGv@+a)K?6bXK{Fx38FJHTXPPk!F&o1t;$Z(}n!&(nGv)piuP&dTb*PCTJ`G7^ntlGQ1ymEZI{A-*gne^m!B*jJpod-FqY za4dLvttQ_aEX2h8p-tYf~*D?8j^=6wG(LLXO{xBAAiXWN`$sef^mYMdjS%h`E3q_+#CIxTz^zRnH$ zYOL#RUgipYj3?YlyXOl1Q;cMGChK(a=?!6X4~&i~@qwH@-7O7A^?|8bX)hyU`@zEUgC35 zX2iuNfYMN}8*ftpmfR^d-k=Sj_4AfrNJL3f4%#B>J_z2;x*W3X?Epx!zxAO$YzPPwBB)TA3Yw|QnkC_KX0$+ujqa3Y zXK~Y8ul1&4Xcs3pcPBR?RVuwcyzN@(1RX&~cI(yHZMSZ>USCe_9jNQfxql~D3;i|K z-l`dNbm*X-VFFHmF7@^qF`Oh$>!sES5&r50L_-7qr%=pUmZ? musP8gsY#hUE}NU4PF+%-Fr`&ce^$0#sp?0GU}k2t)BgcsG}_Vt literal 0 HcmV?d00001 diff --git a/tests/test_serve_model.py b/tests/test_serve_model.py index ff3bb89..df60c21 100644 --- a/tests/test_serve_model.py +++ b/tests/test_serve_model.py @@ -1,23 +1,39 @@ """ Tests for web API. """ +import pickle +from subprocess import run +from unittest.mock import patch + +from bodywork_pipeline_utils.aws import Model from fastapi.testclient import TestClient +from numpy import array from pipeline.serve_model import app test_client = TestClient(app) +def wrapped_model() -> Model: + with open("tests/resources/model.pkl", "r+b") as file: + wrapped_model = pickle.load(file) + return wrapped_model + + +@patch("pipeline.serve_model.wrapped_model", new=wrapped_model(), create=True) def test_web_api_returns_valid_response_given_valid_data(): prediction_request = {"product_code": "SKU001", "orders_placed": 100} prediction_response = test_client.post( "/api/v0.1/time_to_dispatch", json=prediction_request ) + model_obj = wrapped_model() + expected_prediction = model_obj.model.predict(array([[100, 0]])).tolist()[0] assert prediction_response.status_code == 200 - assert "est_hours_to_dispatch" in prediction_response.json().keys() - assert "model_version" in prediction_response.json().keys() + assert prediction_response.json()["est_hours_to_dispatch"] == expected_prediction + assert prediction_response.json()["model_version"] == str(model_obj) +@patch("pipeline.serve_model.wrapped_model", new=wrapped_model(), create=True) def test_web_api_returns_error_code_given_invalid_data(): prediction_request = {"product_code": "SKU001", "foo": 100} prediction_response = test_client.post( @@ -25,3 +41,26 @@ def test_web_api_returns_error_code_given_invalid_data(): ) assert prediction_response.status_code == 422 assert "value_error.missing" in prediction_response.text + + prediction_request = {"product_code": "SKU000", "orders_placed": 100} + prediction_response = test_client.post( + "/api/v0.1/time_to_dispatch", json=prediction_request + ) + assert prediction_response.status_code == 422 + assert "not a valid enumeration member" in prediction_response.text + + prediction_request = {"product_code": "SKU001", "orders_placed": -100} + prediction_response = test_client.post( + "/api/v0.1/time_to_dispatch", json=prediction_request + ) + assert prediction_response.status_code == 422 + assert "ensure this value is greater than or equal to 0" in prediction_response.text + + +def test_web_server_raises_exception_if_passed_invalid_args(): + process = run( + ["python", "-m", "pipeline.serve_model"], capture_output=True, encoding="utf-8" + ) + assert process.returncode != 0 + assert "ERROR" in process.stdout + assert "Invalid arguments passed to serve_model.py" in process.stdout diff --git a/tests/test_train_model.py b/tests/test_train_model.py index b6c3cda..d1bf812 100644 --- a/tests/test_train_model.py +++ b/tests/test_train_model.py @@ -171,35 +171,40 @@ def test_run_job_handles_error_for_invalid_args(): ) assert process_one.returncode != 0 assert "ERROR" in process_one.stdout + assert "Invalid arguments passed to train_model.py" in process_one.stdout process_two = run( - ["python", "pipeline/train_model.py", "my-bucket", "-1", "0.5"], + ["python", "-m", "pipeline.train_model", "my-bucket", "-1", "0.5"], capture_output=True, encoding="utf-8" ) assert process_two.returncode != 0 assert "ERROR" in process_two.stdout + assert "Invalid arguments passed to train_model.py" in process_two.stdout process_three = run( - ["python", "pipeline/train_model.py", "my-bucket", "2", "0.5"], + ["python", "-m", "pipeline.train_model", "my-bucket", "2", "0.5"], capture_output=True, encoding="utf-8" ) assert process_three.returncode != 0 assert "ERROR" in process_three.stdout + assert "Invalid arguments passed to train_model.py" in process_three.stdout - process_two = run( - ["python", "pipeline/train_model.py", "my-bucket", "0.5", "-1"], + process_four = run( + ["python", "-m", "pipeline.train_model", "my-bucket", "0.5", "-1"], capture_output=True, encoding="utf-8" ) - assert process_two.returncode != 0 - assert "ERROR" in process_two.stdout + assert process_four.returncode != 0 + assert "ERROR" in process_four.stdout + assert "Invalid arguments passed to train_model.py" in process_four.stdout - process_three = run( - ["python", "pipeline/train_model.py", "my-bucket", "0.5", "2"], + process_five = run( + ["python", "-m", "pipeline.train_model", "my-bucket", "0.5", "2"], capture_output=True, encoding="utf-8" ) - assert process_three.returncode != 0 - assert "ERROR" in process_three.stdout + assert process_five.returncode != 0 + assert "ERROR" in process_five.stdout + assert "Invalid arguments passed to train_model.py" in process_five.stdout