TediPapajorgji commented on issue #9492:
URL: https://github.com/apache/airflow/issues/9492#issuecomment-1058377047


   This may be helpful: I monkey patched the DockerOperator to support GPUs in my DAG (code below). I'm using version `2.4.1` of [apache-airflow-providers-docker](https://pypi.org/project/apache-airflow-providers-docker/).
   
   The key difference between the patched method below and the original `_run_image_with_mounts` is the addition of this `device_requests` argument (you can also find it in the full code below):
   ```python
   device_requests=[
       docker.types.DeviceRequest(count=-1, capabilities=[['gpu']])  # This part here is what we added to support GPUs
   ]
   ```
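For context, docker-py turns `DeviceRequest(count=-1, capabilities=[['gpu']])` into an entry of `HostConfig.DeviceRequests` in the Docker Engine API, where `Count: -1` means "all available GPUs". The sketch below builds that entry as a plain dict so you can see what the daemon receives. It assumes the Engine API field names (`Driver`, `Count`, `DeviceIDs`, `Capabilities`, `Options`), and `gpu_device_request` is a hypothetical helper, not part of docker-py.

```python
import json
from typing import Any, Dict, List, Optional


def gpu_device_request(
    count: int = -1, device_ids: Optional[List[str]] = None
) -> Dict[str, Any]:
    """Return one HostConfig.DeviceRequests entry (hypothetical helper).

    count=-1 requests every GPU. When device_ids is given, those GPUs are
    pinned instead and Count is set to 0 (assumption: the two fields are
    used one at a time, as with `docker run --gpus`).
    """
    if device_ids:
        count = 0
    return {
        "Driver": "",
        "Count": count,
        "DeviceIDs": list(device_ids or []),
        "Capabilities": [["gpu"]],
        "Options": {},
    }


# Roughly what DeviceRequest(count=-1, capabilities=[['gpu']]) describes
print(json.dumps(gpu_device_request()))
# Only the GPUs with IDs "0" and "1"
print(json.dumps(gpu_device_request(device_ids=["0", "1"])))
```

With docker-py you would normally let `docker.types.DeviceRequest` build this (it also accepts `device_ids` for pinning specific GPUs); the dict form is only meant to show the payload shape.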
   
   I placed this code in my DAG Python file, right before declaring the DAG with `with DAG(...) as dag:`.
   
   ```python
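   from typing import List, Optional, Union
   
   import docker
   from airflow.exceptions import AirflowException
   from airflow.providers.docker.operators.docker import DockerOperator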
   
   # This is copied from
   # https://github.com/apache/airflow/blob/5b45a78dca284e504280964951f079fca1866226/airflow/providers/docker/operators/docker.py#L38
   # so that the monkey patch below can use it
   def stringify(line: Union[str, bytes]):
       """Make sure string is returned even if bytes are passed. Docker stream can return bytes."""
       decode_method = getattr(line, 'decode', None)
       if decode_method:
           return decode_method(encoding='utf-8', errors='surrogateescape')
       else:
           return line
           
   # This is a copy of
   # https://github.com/apache/airflow/blob/5b45a78dca284e504280964951f079fca1866226/airflow/providers/docker/operators/docker.py#L257
   # with a slight modification to support GPU instances
   def new_run_image_with_mounts(
           self, target_mounts, add_tmp_variable: bool
   ) -> Optional[Union[List[str], str]]:
       if add_tmp_variable:
           self.environment['AIRFLOW_TMP_DIR'] = self.tmp_dir
       else:
           self.environment.pop('AIRFLOW_TMP_DIR', None)
       if not self.cli:
           raise Exception("The 'cli' should be initialized before!")
       self.container = self.cli.create_container(
           command=self.format_command(self.command),
           name=self.container_name,
           environment={**self.environment, **self._private_environment},
           host_config=self.cli.create_host_config(
               auto_remove=False,
               mounts=target_mounts,
               network_mode=self.network_mode,
               shm_size=self.shm_size,
               dns=self.dns,
               dns_search=self.dns_search,
               cpu_shares=int(round(self.cpus * 1024)),
               mem_limit=self.mem_limit,
               cap_add=self.cap_add,
               extra_hosts=self.extra_hosts,
               privileged=self.privileged,
               device_requests=[
                   # This part here is what we added to support GPUs
                   docker.types.DeviceRequest(count=-1, capabilities=[['gpu']]),
               ],
           ),
           image=self.image,
           user=self.user,
           entrypoint=self.format_command(self.entrypoint),
           working_dir=self.working_dir,
           tty=self.tty,
       )
       logstream = self.cli.attach(container=self.container['Id'], stdout=True, stderr=True, stream=True)
       try:
           self.cli.start(self.container['Id'])
   
           log_lines = []
           for log_chunk in logstream:
               log_chunk = stringify(log_chunk).strip()
               log_lines.append(log_chunk)
               self.log.info("%s", log_chunk)
   
           result = self.cli.wait(self.container['Id'])
           if result['StatusCode'] != 0:
               joined_log_lines = "\n".join(log_lines)
               raise AirflowException(f'Docker container failed: {repr(result)} lines {joined_log_lines}')
   
           if self.retrieve_output:
               return self._attempt_to_retrieve_result()
           elif self.do_xcom_push:
               log_parameters = {
                   'container': self.container['Id'],
                   'stdout': True,
                   'stderr': True,
                   'stream': True,
               }
               try:
                   if self.xcom_all:
                       return [stringify(line).strip() for line in self.cli.logs(**log_parameters)]
                   else:
                       lines = [stringify(line).strip() for line in self.cli.logs(**log_parameters, tail=1)]
                       return lines[-1] if lines else None
               except StopIteration:
                   # handle the case when there is not a single line to iterate on
                   return None
           return None
       finally:
           if self.auto_remove:
               self.cli.remove_container(self.container['Id'])
   
   # Monkey patch our modified function to the DockerOperator
   DockerOperator._run_image_with_mounts = new_run_image_with_mounts
   ```
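The last line works because Python looks methods up on the class at call time, so rebinding the class attribute affects every instance, including ones created before the patch. A minimal sketch of that mechanism, using a hypothetical `Runner` class in place of DockerOperator so it runs without Airflow or Docker:

```python
class Runner:
    """Hypothetical stand-in for DockerOperator (no Airflow or Docker needed)."""

    def _run_image_with_mounts(self) -> str:
        return "original"

    def execute(self) -> str:
        # Stand-in for the operator's execute path, which ends up calling
        # the private _run_image_with_mounts method.
        return self._run_image_with_mounts()


existing = Runner()  # instance created *before* the patch is applied


def new_run_image_with_mounts(self) -> str:
    # Replacement with the same signature; `self` is bound automatically.
    return "patched"


# Attribute lookup falls through to the class, so rebinding the class
# attribute changes behaviour for existing and future instances alike.
Runner._run_image_with_mounts = new_run_image_with_mounts

print(existing.execute())
print(Runner().execute())
```

Because the patch replaces a private method copied from one specific commit, it is tied to that provider version; if an upgrade changes `_run_image_with_mounts`, the copied body may need refreshing.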
   
   After applying the monkey patch, I use the DockerOperator as usual:
   
   ```python
   train = DockerOperator(
       task_id='task-id-here',
       image='image here',
       api_version='auto',
       auto_remove=True,
       command='some command here',
       docker_url="unix://var/run/docker.sock",
       network_mode="bridge",
   )
   ```
   
   
   You could customize this further by making the `device_requests` list a parameter and passing whatever you need when initializing the DockerOperator, but for the sake of this example I kept it hard-coded.
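One way to make `device_requests` configurable is a subclass that accepts it in `__init__`, instead of a monkey patch (this is a swapped-in approach, not the one used above). The sketch below uses a hypothetical `StandInDockerOperator` base so it runs without Airflow; it only models how the host-config keyword arguments are assembled, and `GPUDockerOperator` is likewise a hypothetical name.

```python
from typing import Any, Dict, List, Optional


class StandInDockerOperator:
    """Hypothetical stand-in for DockerOperator; only builds host-config kwargs."""

    def __init__(self, image: str, network_mode: str = "bridge") -> None:
        self.image = image
        self.network_mode = network_mode

    def host_config_kwargs(self) -> Dict[str, Any]:
        return {"network_mode": self.network_mode}


class GPUDockerOperator(StandInDockerOperator):
    """Takes device_requests at construction time instead of hard-coding them."""

    def __init__(
        self,
        *args: Any,
        device_requests: Optional[List[Dict[str, Any]]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.device_requests = device_requests

    def host_config_kwargs(self) -> Dict[str, Any]:
        kwargs = super().host_config_kwargs()
        # Only forward device_requests when the task asked for devices,
        # so CPU-only tasks keep the default host config.
        if self.device_requests:
            kwargs["device_requests"] = self.device_requests
        return kwargs


all_gpus = {"Driver": "", "Count": -1, "DeviceIDs": [], "Capabilities": [["gpu"]], "Options": {}}
gpu_task = GPUDockerOperator("my-image", device_requests=[all_gpus])
cpu_task = GPUDockerOperator("my-image")
print(gpu_task.host_config_kwargs())
print(cpu_task.host_config_kwargs())
```

In a real DAG, the subclass would instead override `_run_image_with_mounts` (like the patched function above) and pass `device_requests=self.device_requests` to `create_host_config`; that part is omitted here because it needs Airflow and Docker to run.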

