From 1f5ab8aa53faed8ed04d787bb70538b1f53f79cd Mon Sep 17 00:00:00 2001
From: "Yao, Qing"
Date: Thu, 4 Jul 2024 15:48:42 +0800
Subject: [PATCH 1/2] Support initial RAG version.

---
 pyspark_ai/prompt.py     | 84 +++++++++++++++++++++++++++++++++++++++-
 pyspark_ai/pyspark_ai.py | 32 ++++++++++++++-
 2 files changed, 113 insertions(+), 3 deletions(-)

diff --git a/pyspark_ai/prompt.py b/pyspark_ai/prompt.py
index 23fbb291..9196e33e 100644
--- a/pyspark_ai/prompt.py
+++ b/pyspark_ai/prompt.py
@@ -62,6 +62,61 @@
 sql_answer2 = "SELECT COUNT(`Student`) FROM `spark_ai_temp_view_12qcl3` WHERE `Birthday` = 'January 1, 2006'"
 
+sql_question3 = """QUESTION: Given some Spark tables metadata or sqls:
+```
+CREATE TABLE Customers (
+    customer_id INT,
+    customer_name STRING,
+    customer_email STRING
+);
+
+CREATE TABLE Orders (
+    order_id INT,
+    customer_id INT,
+    order_date DATE,
+    order_total DECIMAL(10, 2)
+);
+```
+Write a Spark SQL query to retrieve data based on the provided information: How many orders has each customer placed?
+"""
+
+sql_answer3 = """
+SELECT c.customer_name, COUNT(o.order_id) AS number_of_orders
+FROM Customers c
+JOIN Orders o ON c.customer_id = o.customer_id
+GROUP BY c.customer_name;
+"""
+
+
+sql_question4 = """QUESTION: Given some Spark tables metadata or sqls:
+```
+CREATE TABLE Products (
+    product_id INT,
+    product_name STRING,
+    product_price DECIMAL(10, 2),
+    category STRING
+);
+
+CREATE TABLE Sales (
+    sale_id INT,
+    product_id INT,
+    sale_date DATE,
+    quantity INT,
+    total_sale_amount DECIMAL(10, 2)
+);
+```
+Write a Spark SQL query to retrieve data based on the provided information: Which product has the highest number of sales in terms of quantity sold?
+"""
+
+sql_answer4 = """
+SELECT p.product_name, SUM(s.quantity) AS total_quantity_sold
+FROM Products p
+JOIN Sales s ON p.product_id = s.product_id
+GROUP BY p.product_name
+ORDER BY total_quantity_sold DESC
+LIMIT 1;
+"""
+
 spark_sql_shared_example_1_prefix = f"""{sql_question1}
 Thought: The column names are non-descriptive, but from the sample values I see that column `a` contains
 mountains and column `c` contains countries. So, I will filter on column `c` for 'Japan' and column `a`
 for the mountain.
@@ -180,13 +235,21 @@
 Answer:
 """
 
+SPARK_SQL_SUFFIX_RAG = """\nQUESTION: Given some Spark tables metadata or sqls:
+```
+{comment}
+```
+Write a Spark SQL query to retrieve data based on the provided information: {desc}
+Answer:
+"""
+
 SPARK_SQL_SUFFIX_FOR_AGENT = SPARK_SQL_SUFFIX + "\n{agent_scratchpad}"
 
 SPARK_SQL_PREFIX = """You are an assistant for writing professional Spark SQL queries.
 Given a question, you need to write a Spark SQL query to answer the question.
 The rules that you should follow for answering question:
-1.The answer only consists of Spark SQL query. No explaination. No
-2.SQL statements should be Spark SQL query.
+1.The answer only consists of Spark SQL query. No explanation.
+2.SQL statements should be Spark SQL query.
 3.ONLY use the verbatim column_name in your resulting SQL query; DO NOT include the type.
 4.Use the COUNT SQL function when the query asks for total number of some non-countable column.
 5.Use the SUM SQL function to accumulate the total number of countable column values."""
@@ -239,6 +302,23 @@
     prefix=SPARK_SQL_PREFIX,
 )
 
+SQL_CHAIN_EXAMPLES_RAG = [
+    sql_question3 + f"\nAnswer:\n```{sql_answer3}```",
+    sql_question4 + f"\nAnswer:\n```{sql_answer4}```",
+]
+
+SQL_CHAIN_PROMPT_RAG = PromptTemplate.from_examples(
+    examples=SQL_CHAIN_EXAMPLES_RAG,
+    suffix=SPARK_SQL_SUFFIX_RAG,
+    input_variables=[
+        "view_name",
+        "sample_vals",
+        "comment",
+        "desc",
+    ],
+    prefix=SPARK_SQL_PREFIX,
+)
+
 EXPLAIN_PREFIX = """You are an Apache Spark SQL expert, who can summary what a dataframe retrieves. Given an
 analyzed query plan of a dataframe, you will
 1. convert the dataframe to SQL query. Note that an explain output contains plan
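
The RAG suffix above replaces the temp-view name and sampled rows used by the default prompt with retrieved DDL spliced into the {comment} slot. Below is a minimal sketch of how such a template renders, assuming langchain's PromptTemplate.from_examples API; the one-line example, DDL snippet, and question are illustrative, not taken from this patch or the library:

```
from langchain.prompts import PromptTemplate

prefix = "You are an assistant for writing professional Spark SQL queries."
suffix = """
QUESTION: Given some Spark tables metadata or sqls:
{comment}
Write a Spark SQL query to retrieve data based on the provided information: {desc}
Answer:
"""

prompt = PromptTemplate.from_examples(
    # Few-shot pairs such as sql_question3/sql_answer3 go here; elided for brevity.
    examples=["QUESTION: ...\nAnswer:\nSELECT ..."],
    suffix=suffix,
    input_variables=["comment", "desc"],
    prefix=prefix,
)

# Retrieved DDL fills {comment}; the user question fills {desc}.
print(prompt.format(
    comment="CREATE TABLE Orders (order_id INT, customer_id INT);",
    desc="How many orders are there in total?",
))
```
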
diff --git a/pyspark_ai/pyspark_ai.py b/pyspark_ai/pyspark_ai.py
index 799644ca..ff977785 100644
--- a/pyspark_ai/pyspark_ai.py
+++ b/pyspark_ai/pyspark_ai.py
@@ -12,6 +12,7 @@
 from langchain.prompts.base import BasePromptTemplate
 from langchain.utilities.google_search import GoogleSearchAPIWrapper
 from langchain_community.chat_models import ChatOpenAI
+from langchain_core.vectorstores import VectorStore
 from pyspark.sql import DataFrame, SparkSession
 
 from pyspark_ai.ai_utils import AIUtils
@@ -23,6 +24,7 @@
     PLOT_PROMPT,
     SEARCH_PROMPT,
     SQL_CHAIN_PROMPT,
+    SQL_CHAIN_PROMPT_RAG,
     SQL_PROMPT,
     UDF_PROMPT,
     VERIFY_PROMPT,
@@ -61,6 +63,7 @@ def __init__(
         enable_cache: bool = True,
         cache_file_format: str = "json",
         cache_file_location: Optional[str] = None,
+        vector_db: Optional[VectorStore] = None,
         vector_store_dir: Optional[str] = None,
         vector_store_max_gb: Optional[float] = 16,
         max_tokens_of_web_content: int = 3000,
@@ -110,6 +113,7 @@ def __init__(
             ).search
         else:
             self._cache = None
+        self._vector_db = vector_db
         self._vector_store_dir = vector_store_dir
         self._vector_store_max_gb = vector_store_max_gb
         self._max_tokens_of_web_content = max_tokens_of_web_content
@@ -136,8 +140,12 @@ def _create_llm_chain(self, prompt: BasePromptTemplate):
     @property
     def sql_chain(self):
         if self._sql_chain is None:
+            if self._vector_db:
+                prompt_temp = SQL_CHAIN_PROMPT_RAG
+            else:
+                prompt_temp = SQL_CHAIN_PROMPT
             self._sql_chain = SparkSQLChain(
-                prompt=SQL_CHAIN_PROMPT,
+                prompt=prompt_temp,
                 llm=self._llm,
                 logger=self._logger,
                 spark=self._spark,
@@ -576,6 +584,16 @@ def _get_transform_sql_query_tpch(self, desc: str, table: str, cache: bool) -> str:
         #print(f"-------------------------Current table comment is-------------------------\n\n {comment}\n")
         return self._get_sql_query(table, sample_vals_str, comment, desc)
 
+    def _get_transform_sql_query_rag(self, desc: str) -> str:
+        docs = self._vector_db.similarity_search(desc)
+        reference_contents = []
+        for doc in docs:
+            reference_contents.append(doc.page_content)
+        reference_str = "\n".join([str(val) for val in reference_contents])
+        print(f"-------------------------Current reference contents are:-------------------------\n\n {reference_str}\n")
+        return self._get_sql_query('', '', reference_str, desc)
+
+
     def transform_df_tpch(self, desc: str, table: str, cache: bool = False) -> DataFrame:
         print(f"---------------------TPCH Table {table}------------------------------\n\n")
         start_time = time.time()
@@ -586,6 +604,18 @@ def transform_df_tpch(self, desc: str, table: str, cache: bool = False) -> DataFrame:
         print(f"-------------------------Received query:-------------------------\n\n {sql_query}\n")
         return self._spark.sql(sql_query)
 
+    def transform_rag(self, desc: str, cache: bool = False) -> DataFrame:
+        print("---------------------Start get_transform_sql_query with rag------------------------------\n\n")
+        start_time = time.time()
+        sql_query = self._get_transform_sql_query_rag(desc)
+        end_time = time.time()
+        get_transform_sql_query_time = end_time - start_time
+        print(f"-------------------------End get_transform_sql_query-------------------------\n\n get_transform_sql_query_time: {get_transform_sql_query_time} seconds\n")
+        print(f"-------------------------Received query:-------------------------\n\n {sql_query}\n")
+        return self._spark.sql(sql_query)
+
+
+
     def transform_df(self, df: DataFrame, desc: str, cache: bool = True) -> DataFrame:
         """
         This method applies a transformation to a provided Spark DataFrame,
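
Taken together, the patch wires retrieval into the SQL chain: `similarity_search` returns langchain `Document` objects whose `page_content` is one schema chunk each, and the joined chunks become the `comment` argument of the prompt. A standalone sketch of that retrieval step, assuming `langchain-community`, `faiss-cpu`, and `sentence-transformers` are installed; the two DDL strings are toy stand-ins:

```
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# Two toy schema chunks standing in for the per-table DDL stored in the index.
ddls = [
    "CREATE TABLE customer (c_custkey BIGINT, c_name VARCHAR(25))",
    "CREATE TABLE orders (o_orderkey BIGINT, o_custkey BIGINT, o_orderdate DATE)",
]
db = FAISS.from_texts(ddls, HuggingFaceEmbeddings())

# Top-k chunks most similar to the question (k defaults to 4).
docs = db.similarity_search("Which customer placed the most orders?")
reference_str = "\n".join(doc.page_content for doc in docs)
print(reference_str)  # becomes the {comment} block of SQL_CHAIN_PROMPT_RAG
```
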
From 632209c17a738d559a21ad716ffa958355b261d9 Mon Sep 17 00:00:00 2001
From: "Yao, Qing"
Date: Tue, 9 Jul 2024 13:40:35 +0800
Subject: [PATCH 2/2] Add an example for transform_rag with a FAISS store.

---
 examples/rag_example.py | 60 +++++++++++++++++++++++++++
 examples/tpch.sql       | 92 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 152 insertions(+)
 create mode 100644 examples/rag_example.py
 create mode 100644 examples/tpch.sql

diff --git a/examples/rag_example.py b/examples/rag_example.py
new file mode 100644
index 00000000..4f8a36a9
--- /dev/null
+++ b/examples/rag_example.py
@@ -0,0 +1,60 @@
+
+from pyspark_ai import SparkAI
+from pyspark.sql import DataFrame, SparkSession
+
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.vectorstores import FAISS
+from langchain_community.llms import VLLM
+
+model_name = "sentence-transformers/all-mpnet-base-v2"
+model_kwargs = {'device': 'cpu'}
+encode_kwargs = {'normalize_embeddings': False}
+hf_embedding = HuggingFaceEmbeddings(
+    model_name=model_name,
+    model_kwargs=model_kwargs,
+    encode_kwargs=encode_kwargs
+)
+
+index_name = 'tpch_index'
+
+def load_tpch_schema():
+    with open("tpch.sql", "r") as f:
+        all_lines = f.readlines()
+    tpch_texts = "".join(all_lines).replace("\n", ' ')
+    tables = tpch_texts.split(";")
+    return tables
+
+
+def store_in_faiss(texts):
+    db = FAISS.from_texts(texts, hf_embedding)
+    db.save_local(index_name)
+
+
+if __name__ == '__main__':
+    # Split the TPC-H schema and store it in a FAISS vector store.
+    tpch_texts = load_tpch_schema()
+    store_in_faiss(tpch_texts)
+
+    db = FAISS.load_local(index_name, hf_embedding)
+    # Initialize the vLLM engine.
+    # Arguments for vLLM engine: https://github.com/bigPYJ1151/vllm/blob/e394e2b72c0e0d6e57dc818613d1ea3fc8109ace/vllm/engine/arg_utils.py#L12
+    llm = VLLM(
+        # model="defog/sqlcoder-7b-2",
+        # model="deepseek-ai/deepseek-coder-7b-instruct-v1.5",
+        model="microsoft/Phi-3-mini-4k-instruct",
+        trust_remote_code=True,
+        download_dir="/mnt/DP_disk2/models/Huggingface/"
+    )
+
+    # Show the reference tables retrieved for the question.
+    docs = db.similarity_search("What is the customer's name who has placed the most orders in year of 1995? ")
+    for doc in docs:
+        print(doc.page_content)
+
+    spark_session = SparkSession.builder.appName("text2sql").master("local[*]").enableHiveSupport().getOrCreate()
+    spark_session.sql("show databases").show()
+    spark_session.sql("use tpch;").show()
+    # Initialize and activate SparkAI.
+    spark_ai = SparkAI(llm=llm, verbose=True, spark_session=spark_session, vector_db=db)
+    spark_ai.activate()
+    spark_ai.transform_rag("What is the customer's name who has placed the most orders in year of 1995? ").show()
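
For reference, a hand-written Spark SQL answer to the example question, using the TPC-H tables defined in the next file; the model-generated query should be roughly equivalent (the alias names below are illustrative):

```
# Expected shape of the generated SQL for the question in rag_example.py.
expected_sql = """
SELECT c.c_name, COUNT(o.o_orderkey) AS num_orders
FROM customer c
JOIN orders o ON c.c_custkey = o.o_custkey
WHERE YEAR(o.o_orderdate) = 1995
GROUP BY c.c_name
ORDER BY num_orders DESC
LIMIT 1
"""
# spark_session is the session created in rag_example.py.
spark_session.sql(expected_sql).show()
```
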
").show() diff --git a/examples/tpch.sql b/examples/tpch.sql new file mode 100644 index 00000000..0fa943c3 --- /dev/null +++ b/examples/tpch.sql @@ -0,0 +1,92 @@ +CREATE TABLE nation +( + n_nationkey INTEGER not null, + n_name CHAR(25) not null, + n_regionkey INTEGER not null, + n_comment VARCHAR(152) +); + +CREATE TABLE region +( + r_regionkey INTEGER not null, + r_name CHAR(25) not null, + r_comment VARCHAR(152) +); + +CREATE TABLE part +( + p_partkey BIGINT not null, + p_name VARCHAR(55) not null, + p_mfgr CHAR(25) not null, + p_brand CHAR(10) not null, + p_type VARCHAR(25) not null, + p_size INTEGER not null, + p_container CHAR(10) not null, + p_retailprice DOUBLE PRECISION not null, + p_comment VARCHAR(23) not null +); + +CREATE TABLE supplier +( + s_suppkey BIGINT not null, + s_name CHAR(25) not null, + s_address VARCHAR(40) not null, + s_nationkey INTEGER not null, + s_phone CHAR(15) not null, + s_acctbal DOUBLE PRECISION not null, + s_comment VARCHAR(101) not null +); + +CREATE TABLE partsupp +( + ps_partkey BIGINT not null, + ps_suppkey BIGINT not null, + ps_availqty BIGINT not null, + ps_supplycost DOUBLE PRECISION not null, + ps_comment VARCHAR(199) not null +); + +CREATE TABLE customer +( + c_custkey BIGINT not null, + c_name VARCHAR(25) not null, + c_address VARCHAR(40) not null, + c_nationkey INTEGER not null, + c_phone CHAR(15) not null, + c_acctbal DOUBLE PRECISION not null, + c_mktsegment CHAR(10) not null, + c_comment VARCHAR(117) not null +); + +CREATE TABLE orders +( + o_orderkey BIGINT not null, + o_custkey BIGINT not null, + o_orderstatus CHAR(1) not null, + o_totalprice DOUBLE PRECISION not null, + o_orderdate DATE not null, + o_orderpriority CHAR(15) not null, + o_clerk CHAR(15) not null, + o_shippriority INTEGER not null, + o_comment VARCHAR(79) not null +); + +CREATE TABLE lineitem +( + l_orderkey BIGINT not null, + l_partkey BIGINT not null, + l_suppkey BIGINT not null, + l_linenumber BIGINT not null, + l_quantity DOUBLE PRECISION not null, + l_extendedprice DOUBLE PRECISION not null, + l_discount DOUBLE PRECISION not null, + l_tax DOUBLE PRECISION not null, + l_returnflag CHAR(1) not null, + l_linestatus CHAR(1) not null, + l_shipdate DATE not null, + l_commitdate DATE not null, + l_receiptdate DATE not null, + l_shipinstruct CHAR(25) not null, + l_shipmode CHAR(10) not null, + l_comment VARCHAR(44) not null +); \ No newline at end of file