Skip to content

Commit

Permalink
using model arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
vinicvaz committed May 7, 2024
1 parent 557c982 commit e2db723
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 13 deletions.
12 changes: 0 additions & 12 deletions pieces/ProphetTrainModelPiece/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,6 @@ class InputModel(BaseModel):
title="Input Data File",
description="Path to the input data file. Accepted formats: `.csv`, `.json`. Should use the following format: `ds` (datetime), `y` (target).",
)
# datetime_column_name: str = Field(
# description="Name of the column containing the datetime values."
# )
# target_column_name: str = Field(
# description="Name of the column containing the target values."
# )
test_set_percentage: float = Field(
default=20.0,
ge=1,
le=90,
description="Percentage of the data to use as test set. Default is 20%."
)
growth_trend: GrowthTrend = Field(
default=GrowthTrend.linear,
description="The growth trend of the data. Options are `linear`, `logistic` and `flat`. Default is `linear`."
Expand Down
7 changes: 6 additions & 1 deletion pieces/ProphetTrainModelPiece/piece.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@ def piece_function(self, input_data: InputModel):
else:
raise ValueError("File format not supported. Please pass a CSV or JSON file.")

model = Prophet()
model = Prophet(
seasonality_mode=input_data.seasonality_mode,
growth=input_data.growth_trend,
changepoints=input_data.changepoints,
n_changepoints=input_data.n_changepoints
)
model.fit(df)

# Serialize model
Expand Down

0 comments on commit e2db723

Please sign in to comment.