Rearrange docker build script.
[alexxy/gromacs.git] / admin / containers / scripted_gmx_docker_builds.py
index 08e16d1c171e3f9c5e6d53d93c8b00d6fc15ccdd..31cb993afffc5e120b00b0d0808a3e75b3159a26 100755 (executable)
@@ -534,8 +534,8 @@ def prepare_venv(version: StrictVersion) -> typing.Sequence[str]:
     return commands
 
 
-def add_python_stages(building_blocks: typing.Mapping[str, bb_base],
-                      input_args,
+def add_python_stages(input_args: argparse.Namespace, *,
+                      base: str,
                       output_stages: typing.MutableMapping[str, hpccm.Stage]):
     """Add the stage(s) necessary for the requested venvs.
 
@@ -555,29 +555,17 @@ def add_python_stages(building_blocks: typing.Mapping[str, bb_base],
     # copy is a bit slow and wastes local Docker image space for each filesystem
     # layer.
     pyenv_stage = hpccm.Stage()
-    pyenv_stage += hpccm.primitives.baseimage(image=base_image_tag(input_args),
+    pyenv_stage += hpccm.primitives.baseimage(image=base,
                                               _distro=hpccm_distro_name(input_args),
                                               _as='pyenv')
-    pyenv_stage += building_blocks['compiler']
-    if building_blocks['gdrcopy'] is not None:
-        pyenv_stage += building_blocks['gdrcopy']
-    if building_blocks['ucx'] is not None:
-        pyenv_stage += building_blocks['ucx']
-    pyenv_stage += building_blocks['mpi']
     pyenv_stage += hpccm.building_blocks.packages(ospackages=_python_extra_packages)
 
     for version in [StrictVersion(py_ver) for py_ver in sorted(input_args.venvs)]:
         stage_name = 'py' + str(version)
         stage = hpccm.Stage()
-        stage += hpccm.primitives.baseimage(image=base_image_tag(input_args),
+        stage += hpccm.primitives.baseimage(image=base,
                                             _distro=hpccm_distro_name(input_args),
                                             _as=stage_name)
-        stage += building_blocks['compiler']
-        if building_blocks['gdrcopy'] is not None:
-            stage += building_blocks['gdrcopy']
-        if building_blocks['ucx'] is not None:
-            stage += building_blocks['ucx']
-        stage += building_blocks['mpi']
         stage += hpccm.building_blocks.packages(ospackages=_python_extra_packages)
 
         # TODO: Use a non-root user for testing and Python virtual environments.
@@ -655,6 +643,33 @@ def add_documentation_dependencies(input_args,
         output_stages['main'] += hpccm.primitives.shell(commands=commands)
 
 
+def add_base_stage(name: str,
+                   input_args,
+                   output_stages: typing.MutableMapping[str, hpccm.Stage]):
+    """Establish dependencies that are shared by multiple parallel stages."""
+    # Building blocks are chunks of container-builder instructions that can be
+    # copied to any build stage with the addition operator.
+    building_blocks = collections.OrderedDict()
+    building_blocks['base_packages'] = hpccm.building_blocks.packages(
+        ospackages=_common_packages)
+
+    # These are the most expensive and most reusable layers, so we put them first.
+    building_blocks['compiler'] = get_compiler(input_args, compiler_build_stage=output_stages.get('compiler_build'))
+    building_blocks['gdrcopy'] = get_gdrcopy(input_args, building_blocks['compiler'])
+    building_blocks['ucx'] = get_ucx(input_args, building_blocks['compiler'], building_blocks['gdrcopy'])
+    building_blocks['mpi'] = get_mpi(input_args, building_blocks['compiler'], building_blocks['ucx'])
+
+    # Create the stage from which the targeted image will be tagged.
+    output_stages[name] = hpccm.Stage()
+
+    output_stages[name] += hpccm.primitives.baseimage(image=base_image_tag(input_args),
+                                                      _distro=hpccm_distro_name(input_args),
+                                                      _as=name)
+    for bb in building_blocks.values():
+        if bb is not None:
+            output_stages[name] += bb
+
+
 def build_stages(args) -> typing.Iterable[hpccm.Stage]:
     """Define and sequence the stages for the recipe corresponding to *args*."""
 
@@ -675,17 +690,17 @@ def build_stages(args) -> typing.Iterable[hpccm.Stage]:
     if args.oneapi is not None:
         add_oneapi_compiler_build_stage(input_args=args, output_stages=stages)
 
+    add_base_stage(name='build_base', input_args=args, output_stages=stages)
+
+    # Add Python environments to MPI images, only, so we don't have to worry
+    # about whether to install mpi4py.
+    if args.mpi is not None and len(args.venvs) > 0:
+        add_python_stages(base='build_base', input_args=args, output_stages=stages)
+
     # Building blocks are chunks of container-builder instructions that can be
     # copied to any build stage with the addition operator.
     building_blocks = collections.OrderedDict()
-    building_blocks['base_packages'] = hpccm.building_blocks.packages(
-        ospackages=_common_packages)
 
-    # These are the most expensive and most reusable layers, so we put them first.
-    building_blocks['compiler'] = get_compiler(args, compiler_build_stage=stages.get('compiler_build'))
-    building_blocks['gdrcopy'] = get_gdrcopy(args, building_blocks['compiler'])
-    building_blocks['ucx'] = get_ucx(args, building_blocks['compiler'], building_blocks['gdrcopy'])
-    building_blocks['mpi'] = get_mpi(args, building_blocks['compiler'], building_blocks['ucx'])
     for i, cmake in enumerate(args.cmake):
         building_blocks['cmake' + str(i)] = hpccm.building_blocks.cmake(
             eula=True,
@@ -744,7 +759,9 @@ def build_stages(args) -> typing.Iterable[hpccm.Stage]:
     # Create the stage from which the targeted image will be tagged.
     stages['main'] = hpccm.Stage()
 
-    stages['main'] += hpccm.primitives.baseimage(image=base_image_tag(args))
+    stages['main'] += hpccm.primitives.baseimage(image='build_base',
+                                                 _distro=hpccm_distro_name(args),
+                                                 _as='main')
     for bb in building_blocks.values():
         if bb is not None:
             stages['main'] += bb