Skip to content

Commit

Permalink
Merge pull request #1 from SermetPekin/ref
Browse files Browse the repository at this point in the history
Refactored
  • Loading branch information
SermetPekin authored Dec 2, 2024
2 parents 585bca7 + adeffab commit ac797d3
Show file tree
Hide file tree
Showing 14 changed files with 181 additions and 91 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ __pycache__/
--*.*
--*/

ignore*.py

!example.env

test_pri_*.*
Expand Down
1 change: 0 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
}



logger = logging.getLogger(__name__)


Expand Down
4 changes: 1 addition & 3 deletions evdschat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from evdschat.core.chat import chat, chat_console

__all__ = [
chat, chat_console
]
__all__ = [chat, chat_console]
89 changes: 74 additions & 15 deletions evdschat/common/akeys.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from abc import ABC
import os
from pathlib import Path
from typing import Union
from typing import Union, Dict
import time

from evdschat.common.globals import WARNING_SLEEP_SECONDS


class ErrorApiKey(Exception):
Expand All @@ -33,13 +36,21 @@ def __init__(self, message="There is an issue with the provided API key."):


class ApiKey(ABC):
def __init__(self, key: str) -> None:
def __init__(self, key: str, key_name: str = 'ApiKey') -> None:
self.key = key
self.key_name = key_name
self.check()

def __str__(self):
return self.key

def msg_before_raise(self):
showApiKeyMessage(self.__class__.__name__)
# create_env_example_file()

def check(self):
if isinstance(self.key, type(None)):
raise ErrorApiKey('Api key not set. Please see the documentation.')
raise ErrorApiKey("Api key not set. Please see the documentation.")
if not isinstance(self.key, str) or len(str(self.key)) < 5:
raise ErrorApiKey(f"Api key {self.key} is not a valid key")
return True
Expand All @@ -51,16 +62,52 @@ def set_key(self, key: str):
self.key = key


def sleep(number: int):
time.sleep(number)


def showApiKeyMessage(cls_name: str) -> None:
msg = f"""
{cls_name} not found.
create `.env` file and put necessary API keys for EVDS and {cls_name}
see documentation for details.
"""

print(msg)
sleep(WARNING_SLEEP_SECONDS)


def write_env_example(file_name: Path):
content = (
"\nOPENAI_API_KEY=sk-proj-ABCDEFGIJKLMNOPQRSTUXVZ\nEVDS_API_KEY=ABCDEFGIJKLMNOP"
)
with open(file_name, "w") as f:
f.write(content)
print("Example .env file was created.")
sleep(WARNING_SLEEP_SECONDS)


def create_env_example_file():
file_name = Path(".env")
if not file_name.exists():
write_env_example(file_name)


class OpenaiApiKey(ApiKey):
def __init__(self, key: str) -> None:
super().__init__(key)
self.key = key
self.check()

def check(self) -> Union[bool, None]:
self.key_name = 'openai_api_key'
# self.check()

def check(self, raise_=True) -> Union[bool, None]:
if not str(self.key).startswith("sk-") and len(str(self.key)) < 6:
raise ErrorApiKey(f"{self.key} is not a valid key")
self.msg_before_raise()
if raise_:
raise ErrorApiKey(f"{self.key} is not a valid key")
return False
return True


Expand All @@ -70,7 +117,7 @@ class EvdsApiKey(ApiKey): ...
class MistralApiKey(ApiKey): ...


def load_api_keys() -> Union[dict[str, str], None]:
def load_api_keys() -> Dict[str, OpenaiApiKey | EvdsApiKey]:
from dotenv import load_dotenv

env_file = Path(".env")
Expand All @@ -83,19 +130,31 @@ def load_api_keys() -> Union[dict[str, str], None]:
}


def get_openai_key():
def load_api_keys_string() -> Dict[str, str]:
from dotenv import load_dotenv

env_file = Path(".env")
load_dotenv(env_file)
openai_api_key = os.getenv("OPENAI_API_KEY")
evds_api_key = os.getenv("EVDS_API_KEY")
return {
"OPENAI_API_KEY": openai_api_key,
"EVDS_API_KEY": evds_api_key,
}


def get_openai_key() -> OpenaiApiKey:
d = load_api_keys()
return d["OPENAI_API_KEY"].key
return d["OPENAI_API_KEY"]


def get_openai_key_string() -> str | None:
d = load_api_keys_string()
return d["OPENAI_API_KEY"]

# @dataclass

class ApiKeyManager(BaseModel):
api_key: ApiKey = Field(default_factory=lambda: ApiKey())

class Config:
arbitrary_types_allowed = True

# def __get_pydantic_core_schema__(cls, handler):
# # Generate a schema if necessary, or skip it
# return handler.generate_schema(cls)
31 changes: 17 additions & 14 deletions evdschat/common/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,47 @@
from pathlib import Path
from importlib import resources
from typing import Union
from .github_actions import PytestTesting
from .github_actions import PytestTesting


class PostParams(ctypes.Structure):
_fields_ = [
("url", ctypes.c_char_p),
("prompt", ctypes.c_char_p),
("api_key", ctypes.c_char_p),
("proxy_url", ctypes.c_char_p)
("proxy_url", ctypes.c_char_p),
]

def get_exec_file(test = False ) -> Path :

def get_exec_file(test=False) -> Path:

executable_name = "libpost_request.so"
if platform.system() == "Windows":
executable_name = "libpost_request.dll"

if test or PytestTesting().is_testing():
executable_path = Path(".") / executable_name
if executable_path.is_file() :
if executable_path.is_file():
return executable_path
return False

def check_c_executable(test = False ) -> Union[Path, bool]:
executable_name= get_exec_file(test )
return False


def check_c_executable(test=False) -> Union[Path, bool]:
executable_name = get_exec_file(test)
if not executable_name:
return False
return False
try:
with resources.path("evdschat", executable_name) as executable_path:
if executable_path.is_file() and os.access(executable_path, os.X_OK):
return executable_path
except FileNotFoundError:
return False


lib_path = check_c_executable()
if lib_path:
lib = ctypes.CDLL(lib_path)

lib.post_request.argtypes = [ctypes.POINTER(PostParams)]
lib.post_request.restype = ctypes.c_char_p

Expand All @@ -54,16 +58,15 @@ def c_caller(params):

def c_caller_main(prompt, api_key, url, proxy=None):
prompt = prompt.replace("\n", " ")

params = PostParams(
url=url.encode("utf-8"),
prompt=prompt.encode("utf-8"),
api_key=api_key.encode("utf-8"),
proxy_url=proxy.encode("utf-8") if proxy else None
proxy_url=proxy.encode("utf-8") if proxy else None,
)

return c_caller(params)



else:
c_caller_main = None
8 changes: 7 additions & 1 deletion evdschat/common/github_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,23 @@
# limitations under the License.

import sys


class GithubActions:
def is_testing(self):
return "hostedtoolcache" in sys.argv[0]


class PytestTesting:
def is_testing(self):
# print(" sys.argv[0]" , sys.argv[0])
return "pytest" in sys.argv[0]


def get_input(msg, default=None):
if GithubActions().is_testing() or PytestTesting().is_testing():
if not default:
print("currently testing with no default ")
return False
return default
return input(msg)
return input(msg)
3 changes: 3 additions & 0 deletions evdschat/common/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

WARNING_SLEEP_SECONDS = 6
DEFAULT_CHAT_API_URL = "https://evdspychat-dev2-1.onrender.com/api/ask"


def global_mock():
template = """
Expand Down
1 change: 0 additions & 1 deletion evdschat/core/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from pydantic import BaseModel, Field

Expand Down
19 changes: 12 additions & 7 deletions evdschat/core/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ class GotUndefinedResult(BaseException): ...


def chat(
prompt: str,
getter: ModelAbstract = OpenAI(),
debug=False,
test=False,
force=False,
prompt: str,
getter: ModelAbstract = None,
debug=False,
test=False,
force=False,
) -> Union[Tuple[ResultChat, Notes], None]:
"""
Function to process the chat prompt and return the result.
Expand All @@ -45,6 +45,11 @@ def chat(
:return: DataFrame or Result Instance with .data (DataFrame), .metadata (DataFrame), and .to_excel (Callable).
"""

if getter is None:
getter = OpenAI()



if not force and PytestTesting().is_testing():
test = True

Expand Down Expand Up @@ -74,8 +79,8 @@ def chat(
raise GotUndefinedResult()
result, notes = res
if isinstance(result, ResultChat):
return result, notes
raise NotImplementedError("Unknown Result type ")
return result, notes
raise NotImplementedError("Unknown Result type ")


def chat_console() -> None:
Expand Down
Loading

0 comments on commit ac797d3

Please sign in to comment.