diff --git a/impc_etl/jobs/load/impc_kg/parameter_mapper.py b/impc_etl/jobs/load/impc_kg/parameter_mapper.py
index b26ec703..00445437 100644
--- a/impc_etl/jobs/load/impc_kg/parameter_mapper.py
+++ b/impc_etl/jobs/load/impc_kg/parameter_mapper.py
@@ -1,96 +1,90 @@
-import luigi
-from impc_etl.jobs.load.impc_bulk_api.impc_api_mapper import to_camel_case
-from luigi.contrib.spark import PySparkTask
-from pyspark import SparkContext
-from pyspark.sql import SparkSession
-from pyspark.sql.functions import col, trim, when, lit
+"""
+Module to generate the statistical result data as JSON for the KG.
+"""
+import logging
+import textwrap
 
-from impc_etl.jobs.load.impc_kg.impc_kg_helper import add_unique_id
-from impc_etl.jobs.load.solr.pipeline_mapper import ImpressToParameterMapper
-from impc_etl.workflow.config import ImpcConfig
+from airflow.sdk import Variable, asset
+from impc_etl.utils.airflow import create_input_asset, create_output_asset
+from impc_etl.utils.spark import with_spark_session
 
-class ImpcKgParameterMapper(PySparkTask):
-    """
-    PySpark Task class to parse GenTar Product report data.
-    """
+task_logger = logging.getLogger("airflow.task")
+dr_tag = Variable.get("data_release_tag")
 
-    #: Name of the Spark task
-    name: str = "ImpcKgParameterMapper"
+impress_parameter_parquet_asset = create_input_asset("output/impress_parameter_parquet")
 
-    #: Path of the output directory where the new parquet file will be generated.
-    output_path: luigi.Parameter = luigi.Parameter()
+parameter_output_asset = create_output_asset("/impc_kg/parameter_json")
 
-    def requires(self):
-        return [ImpressToParameterMapper()]
-
-    def output(self):
-        """
-        Returns the full parquet path as an output for the Luigi Task
-        (e.g. impc/dr15.2/parquet/product_report_parquet)
+@asset.multi(
+    schedule=[impress_parameter_parquet_asset],
+    outlets=[parameter_output_asset],
+    dag_id=f"{dr_tag}_impc_kg_parameter_mapper",
+    description=textwrap.dedent(
         """
-        return ImpcConfig().get_target(f"{self.output_path}/impc_kg/parameter_json")
-
-    def app_options(self):
+        PySpark task to create the parameter JSON fro the Knowledge Graph
+        from the output of the IMPReSS parameter mapper.
         """
-        Generates the options pass to the PySpark job
-        """
-        return [
-            self.input()[0].path,
-            self.output().path,
-        ]
-
-    def main(self, sc: SparkContext, *args):
-        """
-        Takes in a SparkContext and the list of arguments generated by `app_options` and executes the PySpark job.
-        """
-        spark = SparkSession(sc)
-
-        # Parsing app options
-        input_parquet_path = args[0]
-        output_path = args[1]
-
-        input_df = spark.read.parquet(input_parquet_path)
-        input_df = add_unique_id(
-            input_df,
-            "parameter_id",
-            ["pipeline_stable_id", "procedure_stable_id", "parameter_stable_id"],
-        )
-
-        input_df = input_df.drop("name")
-
-        input_df = input_df.withColumn("unit_x", trim(col("unit_x"))).withColumn(
-            "unit_x", when(~(col("unit_x") == ""), col("unit_x")).otherwise(lit(None))
-        )
-        input_df = input_df.withColumn("unit_y", trim(col("unit_y"))).withColumn(
-            "unit_y", when(~(col("unit_y") == ""), col("unit_y")).otherwise(lit(None))
-        )
-        input_df = input_df.withColumnRenamed(
-            "mp_id",
-            "potentialPhenotypeTermCuries",
-        )
-
-        input_df = input_df.withColumnRenamed(
-            "parameter_name",
-            "name",
-        )
-
-        output_cols = [
-            "parameter_id",
-            "parameter_stable_id",
-            "name",
-            "data_type",
-            "unit_x",
-            "unit_y",
-            "potentialPhenotypeTermCuries",
-        ]
-        output_df = input_df.select(*output_cols).distinct()
-        for col_name in output_df.columns:
-            output_df = output_df.withColumnRenamed(
-                col_name,
-                to_camel_case(col_name),
-            )
-        output_df.distinct().coalesce(1).write.option("ignoreNullFields", "false").json(
-            output_path, mode="overwrite"
+    ),
+    tags=["impc_kg"],
+)
+@with_spark_session
+def impc_kg_parameter_mapper():
+
+    from impc_etl.jobs.load.impc_web_api.impc_web_api_helper import to_camel_case
+    from impc_etl.jobs.load.impc_kg.impc_kg_helper import add_unique_id
+
+    from pyspark.sql import SparkSession
+    from pyspark.sql.functions import (
+        trim,
+        col,
+        lit,
+        when,
+    )
+
+    spark = SparkSession.builder.getOrCreate()
+
+    input_df = spark.read.parquet(impress_parameter_parquet_asset.uri)
+    input_df = add_unique_id(
+        input_df,
+        "parameter_id",
+        ["pipeline_stable_id", "procedure_stable_id", "parameter_stable_id"],
+    )
+
+    input_df = input_df.drop("name")
+
+    input_df = input_df.withColumn("unit_x", trim(col("unit_x"))).withColumn(
+        "unit_x", when(~(col("unit_x") == ""), col("unit_x")).otherwise(lit(None))
+    )
+    input_df = input_df.withColumn("unit_y", trim(col("unit_y"))).withColumn(
+        "unit_y", when(~(col("unit_y") == ""), col("unit_y")).otherwise(lit(None))
+    )
+    input_df = input_df.withColumnRenamed(
+        "mp_id",
+        "potentialPhenotypeTermCuries",
+    )
+
+    input_df = input_df.withColumnRenamed(
+        "parameter_name",
+        "name",
+    )
+
+    output_cols = [
+        "parameter_id",
+        "parameter_stable_id",
+        "name",
+        "data_type",
+        "unit_x",
+        "unit_y",
+        "potentialPhenotypeTermCuries",
+    ]
+    output_df = input_df.select(*output_cols).distinct()
+    for col_name in output_df.columns:
+        output_df = output_df.withColumnRenamed(
+            col_name,
+            to_camel_case(col_name),
         )
+    output_df.distinct().coalesce(1).write.option("ignoreNullFields", "false").json(
+        parameter_output_asset.uri, mode="overwrite"
+    )