import kfp.components as comp
from collections import OrderedDict
from kubernetes import client as k8s_client

{# PIPELINE FUNCTION BLOCKS: render every entry of `block_functions`, one after another #}
{% for function_source in block_functions -%}
{{ function_source }}
{% endfor -%}

{# DEFINE PIPELINE TASKS FROM FUNCTIONS
   For each name in `block_functions_names`, bind `<name>_op` to the container op built
   from the function of the same name. The `base_image` argument is only rendered when
   `docker_base_image` is not the empty string.
   The blank line before `endfor` keeps the rendered spacing between ops unchanged.
#}
{%- for fn_name in block_functions_names %}
{{ fn_name }}_op = comp.func_to_container_op({{ fn_name }}{% if docker_base_image != '' %}, base_image='{{ docker_base_image }}'{% endif %})

{% endfor -%}

{# DECLARE PIPELINE
   Renders `auto_generated_pipeline`, which:
     1. fills `pvolumes_dict` (mount point -> volume) from every entry of `volumes`,
        plus an optional marshal volume when `marshal_volume` is set;
     2. creates one task per name in `block_functions_names`, mounts all volumes on it
        and passes the entries of `block_function_prevs[name]` to `.after(...)`;
     3. adds a VolumeSnapshotOp for every volume whose `snapshot` field is truthy.
   NOTE(review): `pipeline_name` and `pipeline_description` are placed inside
   single-quoted literals without escaping; a quote in either value would break the
   rendered file. Confirm the caller sanitizes them.
#}
import kfp.dsl as dsl
@dsl.pipeline(
    name='{{ pipeline_name }}',
    description='{{ pipeline_description }}'
)
def auto_generated_pipeline({{ pipeline_arguments }}):
    pvolumes_dict = OrderedDict()

    {# Every volume entry renders code that binds `volume` and registers it under its mount point. #}
    {% for vol in volumes -%}
    {% set name = vol['name'] %}
    {% set mountpoint = vol['mount_point'] %}
    {% set pvc_size = vol['size']|default('') + vol['size_type']|default('') %}
    {% set annotations = vol['annotations']|default({}) %}
    annotations = {{ annotations }}

    {% if vol['type'] == 'pv' %}
    {# 'pv': build a PVC manifest whose `volume_name` is the entry's name and submit it with a VolumeOp.
       `pvc_size` is the concatenation of size and size_type, so it is rendered as a quoted
       string literal (same as the 'new_pvc' branch); unquoted it would not be valid Python. #}

    pvc{{ loop.index }} = k8s_client.V1PersistentVolumeClaim(
        api_version="v1",
        kind="PersistentVolumeClaim",
        metadata=k8s_client.V1ObjectMeta(
            name="{{ name }}-claim-{{ pipeline_name }}"
        ),
        spec=k8s_client.V1PersistentVolumeClaimSpec(
            volume_name="{{ name }}",
            access_modes=['ReadWriteOnce'],
            resources=k8s_client.V1ResourceRequirements(
                requests={"storage": "{{ pvc_size }}"}
            )
        )
    )

    vop{{ loop.index }} = dsl.VolumeOp(
        name="pvc-data{{ loop.index }}",
        annotations=annotations,
        k8s_resource=pvc{{ loop.index }}
    )
    volume = vop{{ loop.index }}.volume

    {% elif vol['type'] == 'pvc' %}
    {# 'pvc': wrap the variable `vol_<mount point with slashes turned into underscores>`;
       presumably a pipeline argument declared through `pipeline_arguments` (confirm against the caller). #}

    volume = dsl.PipelineVolume(pvc=vol_{{ mountpoint.replace('/', '_').strip('_') }})

    {% elif vol['type'] == 'new_pvc' %}
    {# 'new_pvc': create the claim with a VolumeOp. When a 'rok/origin' annotation is present it is
       overridden at run time by the variable `rok_<name>_url`; presumably a pipeline argument (confirm). #}
    {% if annotations.get('rok/origin') %}
    annotations['rok/origin'] = rok_{{ name.replace('-', '_') }}_url
    {% endif %}

    vop{{ loop.index }} = dsl.VolumeOp(
        name='create-volume-{{ loop.index }}',
        resource_name='{{ name }}',
        {%- if annotations %}
        annotations=annotations,
        {% endif -%}
        size='{{ pvc_size }}'
    )
    volume = vop{{ loop.index }}.volume

    {% endif %}

    pvolumes_dict['{{ mountpoint }}'] = volume

    {% endfor %}

    {# Optional 1Gi ReadWriteMany volume mounted at `marshal_path` on every task. #}
    {% if marshal_volume %}
    marshal_vop = dsl.VolumeOp(
        name="kale_marshal_volume",
        resource_name="kale-marshal-pvc",
        modes=dsl.VOLUME_MODE_RWM,
        size="1Gi"
    )
    pvolumes_dict['{{ marshal_path }}'] = marshal_vop.volume
    {% endif %}

    {# One task per function: all volumes mounted, fixed working directory, container user id 0. #}
    {% for name in block_functions_names %}

    {{ name }}_task = {{ name }}_op({{ pipeline_arguments_names }})\
                            .add_pvolumes(pvolumes_dict)\
                            .after({{ block_function_prevs[ name ]|join(', ') }})
    {{ name }}_task.container.working_dir = "{{ working_dir }}"
    {{ name }}_task.container.set_security_context(k8s_client.V1SecurityContext(run_as_user=0))
    {% if auto_snapshot -%}
    mlpipeline_ui_metadata = {'mlpipeline-ui-metadata': '/mlpipeline-ui-metadata.json'}
    {{ name }}_task.output_artifact_paths.update(mlpipeline_ui_metadata)
    {% endif -%}

    {% endfor %}

    {# Snapshots: taken after every task named in `leaf_nodes`.
       `add_suffix` is not a built-in Jinja filter; it is assumed to be registered by the rendering code.
       NOTE(review): `vop<index>` is only defined by the 'pv' and 'new_pvc' branches above, so a
       'pvc' volume with `snapshot` set would reference an undefined name; confirm this cannot happen. #}
    {% for vol in volumes -%}
    {% if vol['snapshot'] %}
    snapshot{{ loop.index }} = dsl.VolumeSnapshotOp(
        name='snapshot-volume-{{ loop.index }}',
        resource_name='{{ vol['snapshot_name'] }}',
        volume=vop{{ loop.index }}.volume.after({{ leaf_nodes| map('add_suffix', '_task') | join(', ') }})
    )
    {% endif %}
    {% endfor %}


{# When the rendered script is executed directly it compiles the pipeline and submits one run #}
if __name__ == "__main__":
    import kfp
    import kfp.compiler as compiler

    # Compile the pipeline into an archive named after the pipeline function
    package_path = auto_generated_pipeline.__name__ + '.pipeline.tar.gz'
    compiler.Compiler().compile(auto_generated_pipeline, package_path)

    # Create the experiment through the KFP client
    kfp_client = kfp.Client()
    experiment = kfp_client.create_experiment('{{ experiment_name }}')

    # Submit a single run of the compiled archive, with no run parameters
    run_result = kfp_client.run_pipeline(
        experiment.id,
        '{{ pipeline_name }}_run',
        package_path,
        {}
    )
