"""
Module to generate the gene-phenotype association data as JSON for the KG.
"""
import logging
import textwrap
from airflow.sdk import Variable, asset

from impc_etl.utils.airflow import create_input_asset, create_output_asset
from impc_etl.utils.spark import with_spark_session

task_logger = logging.getLogger("airflow.task")

# Data release tag read from the Airflow Variable `data_release_tag`; it is
# used below as the prefix of the DAG id.
dr_tag = Variable.get("data_release_tag")

# Input: JSON produced for the impc_web_api gene_phenotype_hits service.
genotype_phenotype_hits_json_asset = create_input_asset(
    "output/impc_web_api/gene_phenotype_hits_service_json"
)

# Output: location the Knowledge Graph gene-phenotype association JSON is
# written to.
gene_phenotype_association_output_asset = create_output_asset(
    "/impc_kg/gene_phenotype_association_json"
)


@asset.multi(
    schedule=[genotype_phenotype_hits_json_asset],
    outlets=[gene_phenotype_association_output_asset],
    dag_id=f"{dr_tag}_impc_kg_gene_phenotype_association_mapper",
    description=textwrap.dedent(
        """
        PySpark task to create the Knowledge Graph JSON files for
        gene-phenotype associations from the impc_web_api gene_phenotype_hits_service_json data.
        """
    ),
    tags=["impc_kg"],
)
@with_spark_session
def impc_kg_gene_phenotype_association_mapper():
    """
    Build the Knowledge Graph gene-phenotype association JSON.

    Steps performed:

    1. Read the gene-phenotype hits JSON from the input asset URI.
    2. Add the ``parameter_id``, ``phenotyping_center_id``, ``mouse_gene_id``
       and ``mouse_allele_id`` columns through ``add_unique_id``.
    3. Rename ``id`` to ``genePhenotypeAssociationId`` and ``datasetId`` to
       ``statisticalResultId``.
    4. Keep only the output columns, drop duplicate rows, and pass every
       column name through ``to_camel_case``.
    5. Write the result as gzip-compressed JSON to the output asset URI,
       overwriting whatever is already there.

    The function takes no arguments and returns nothing; its only effect is
    the JSON written in step 5.
    """
    # NOTE(review): these imports live inside the task body rather than at
    # module level -- presumably so the module can be parsed without loading
    # PySpark and the helper modules; confirm before moving them.
    from impc_etl.jobs.load.impc_web_api.impc_web_api_helper import to_camel_case
    from impc_etl.jobs.load.impc_kg.impc_kg_helper import add_unique_id

    from pyspark.sql import SparkSession

    # NOTE(review): `with_spark_session` presumably prepares the session that
    # getOrCreate() returns here -- confirm in impc_etl.utils.spark.
    spark = SparkSession.builder.getOrCreate()

    input_df = spark.read.json(genotype_phenotype_hits_json_asset.uri)

    # Each call adds one id column (second argument) built from the listed
    # source columns (third argument). How the id is derived is defined by
    # `add_unique_id` in impc_kg_helper.
    input_df = add_unique_id(
        input_df,
        "parameter_id",
        ["pipelineStableId", "procedureStableId", "parameterStableId"],
    )
    input_df = add_unique_id(
        input_df, "phenotyping_center_id", ["phenotypingCentre"]
    )
    input_df = add_unique_id(input_df, "mouse_gene_id", ["mgiGeneAccessionId"])
    input_df = add_unique_id(input_df, "mouse_allele_id", ["alleleAccessionId"])

    # Rename the source identifiers to the names used in the output.
    input_df = input_df.withColumnRenamed("id", "genePhenotypeAssociationId")
    input_df = input_df.withColumnRenamed("datasetId", "statisticalResultId")

    output_cols = [
        "genePhenotypeAssociationId",
        "alleleAccessionId",
        "phenotyping_center_id",
        "statisticalResultId",
        "effectSize",
        "lifeStageName",
        "mouse_gene_id",
        "pValue",
        "parameter_id",
        "phenotype",
        "projectName",
        "sex",
        "zygosity",
        "mouse_allele_id",
    ]
    output_df = input_df.select(*output_cols).distinct()

    # The snake_case id columns added above sit next to camelCase source
    # columns; run every name through `to_camel_case` so the output uses a
    # single naming style.
    for col_name in output_df.columns:
        output_df = output_df.withColumnRenamed(
            col_name,
            to_camel_case(col_name),
        )

    # coalesce(1) reduces the DataFrame to a single partition before writing.
    output_df.coalesce(1).write.json(
        gene_phenotype_association_output_asset.uri,
        mode="overwrite",
        compression="gzip",
    )