diff --git a/pieces/ProphetTrainModelPiece/models.py b/pieces/ProphetTrainModelPiece/models.py index ccb586e..177d17e 100644 --- a/pieces/ProphetTrainModelPiece/models.py +++ b/pieces/ProphetTrainModelPiece/models.py @@ -20,18 +20,6 @@ class InputModel(BaseModel): title="Input Data File", description="Path to the input data file. Accepted formats: `.csv`, `.json`. Should use the following format: `ds` (datetime), `y` (target).", ) - # datetime_column_name: str = Field( - # description="Name of the column containing the datetime values." - # ) - # target_column_name: str = Field( - # description="Name of the column containing the target values." - # ) - test_set_percentage: float = Field( - default=20.0, - ge=1, - le=90, - description="Percentage of the data to use as test set. Default is 20%." - ) growth_trend: GrowthTrend = Field( default=GrowthTrend.linear, description="The growth trend of the data. Options are `linear`, `logistic` and `flat`. Default is `linear`." diff --git a/pieces/ProphetTrainModelPiece/piece.py b/pieces/ProphetTrainModelPiece/piece.py index 6af429b..c08efea 100644 --- a/pieces/ProphetTrainModelPiece/piece.py +++ b/pieces/ProphetTrainModelPiece/piece.py @@ -20,7 +20,12 @@ def piece_function(self, input_data: InputModel): else: raise ValueError("File format not supported. Please pass a CSV or JSON file.") - model = Prophet() + model = Prophet( + seasonality_mode=input_data.seasonality_mode, + growth=input_data.growth_trend, + changepoints=input_data.changepoints, + n_changepoints=input_data.n_changepoints + ) model.fit(df) # Serialize model