TediPapajorgji commented on issue #9492: URL: https://github.com/apache/airflow/issues/9492#issuecomment-1058377047
This may be helpful: I monkey patched the `DockerOperator` to support GPUs in my DAG (code below). I'm using version `2.4.1` of [apache-airflow-providers-docker](https://pypi.org/project/apache-airflow-providers-docker/). Compared with the upstream method, the key difference in the patch below is the addition of the following (you can also find it in the full code further down):

```python
device_requests=[
    docker.types.DeviceRequest(count=-1, capabilities=[['gpu']])  # This part here is what we added to support GPUs
]
```

I placed this code inside my DAG Python file, right before I declare the DAG with `with DAG(....) as dag:`

```python
from typing import List, Optional, Union

import docker
from airflow.exceptions import AirflowException
from airflow.providers.docker.operators.docker import DockerOperator


# This is ripped from https://github.com/apache/airflow/blob/5b45a78dca284e504280964951f079fca1866226/airflow/providers/docker/operators/docker.py#L38
# so that the monkey patch below can operate
def stringify(line: Union[str, bytes]):
    """Make sure string is returned even if bytes are passed. Docker stream can return bytes."""
    decode_method = getattr(line, 'decode', None)
    if decode_method:
        return decode_method(encoding='utf-8', errors='surrogateescape')
    else:
        return line


# This is a copy from https://github.com/apache/airflow/blob/5b45a78dca284e504280964951f079fca1866226/airflow/providers/docker/operators/docker.py#L257
# with a slight modification to support GPU instances
def new_run_image_with_mounts(
    self, target_mounts, add_tmp_variable: bool
) -> Optional[Union[List[str], str]]:
    if add_tmp_variable:
        self.environment['AIRFLOW_TMP_DIR'] = self.tmp_dir
    else:
        self.environment.pop('AIRFLOW_TMP_DIR', None)
    if not self.cli:
        raise Exception("The 'cli' should be initialized before!")
    self.container = self.cli.create_container(
        command=self.format_command(self.command),
        name=self.container_name,
        environment={**self.environment, **self._private_environment},
        host_config=self.cli.create_host_config(
            auto_remove=False,
            mounts=target_mounts,
            network_mode=self.network_mode,
            shm_size=self.shm_size,
            dns=self.dns,
            dns_search=self.dns_search,
            cpu_shares=int(round(self.cpus * 1024)),
            mem_limit=self.mem_limit,
            cap_add=self.cap_add,
            extra_hosts=self.extra_hosts,
            privileged=self.privileged,
            # This part here is what we added to support GPUs:
            # count=-1 requests all available GPUs (like `docker run --gpus all`)
            device_requests=[
                docker.types.DeviceRequest(count=-1, capabilities=[['gpu']])
            ],
        ),
        image=self.image,
        user=self.user,
        entrypoint=self.format_command(self.entrypoint),
        working_dir=self.working_dir,
        tty=self.tty,
    )
    logstream = self.cli.attach(container=self.container['Id'], stdout=True, stderr=True, stream=True)
    try:
        self.cli.start(self.container['Id'])

        log_lines = []
        for log_chunk in logstream:
            log_chunk = stringify(log_chunk).strip()
            log_lines.append(log_chunk)
            self.log.info("%s", log_chunk)

        result = self.cli.wait(self.container['Id'])
        if result['StatusCode'] != 0:
            joined_log_lines = "\n".join(log_lines)
            raise AirflowException(f'Docker container failed: {repr(result)} lines {joined_log_lines}')

        if self.retrieve_output:
            return self._attempt_to_retrieve_result()
        elif self.do_xcom_push:
            log_parameters = {
                'container': self.container['Id'],
                'stdout': True,
                'stderr': True,
                'stream': True,
            }
            try:
                if self.xcom_all:
                    return [stringify(line).strip() for line in self.cli.logs(**log_parameters)]
                else:
                    lines = [stringify(line).strip() for line in self.cli.logs(**log_parameters, tail=1)]
                    return lines[-1] if lines else None
            except StopIteration:
                # handle the case when there is not a single line to iterate on
                return None
        return None
    finally:
        if self.auto_remove:
            self.cli.remove_container(self.container['Id'])


# Monkey patch our modified function onto the DockerOperator
DockerOperator._run_image_with_mounts = new_run_image_with_mounts
```

After monkey patching, I just use the `DockerOperator` as usual:

```python
train = DockerOperator(
    task_id='task-id-here',
    image='image here',
    api_version='auto',
    auto_remove=True,
    command='some command here',
    docker_url="unix://var/run/docker.sock",
    network_mode="bridge",
)
```

You can customize this further and make
the `device_requests` list a parameter of the patched function, then pass whatever you need to it when initializing the `DockerOperator`; for the sake of this example I didn't do that.
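The parameterization suggested above could also be done without copying the whole upstream method. Below is a minimal sketch of that alternative. It assumes the patched method builds its host config through `self.cli.create_host_config(...)`, as the copied provider method above does. `with_device_requests` is a hypothetical helper name and is not part of Airflow or docker-py. The helper wraps the original method and injects `device_requests` only into that one call. Any `device_requests` the caller passes explicitly are left untouched.

```python
import functools
from typing import Any, Callable, List


def with_device_requests(run_method: Callable, device_requests: List[Any]) -> Callable:
    """Return a version of ``run_method`` that injects ``device_requests``.

    Assumption (hypothetical helper): ``run_method`` behaves like the
    provider's ``_run_image_with_mounts`` and builds its host config by
    calling ``self.cli.create_host_config(...)``.
    """

    @functools.wraps(run_method)
    def wrapper(self, *args, **kwargs):
        original = self.cli.create_host_config

        def create_host_config(*hc_args, **hc_kwargs):
            # Add device_requests only if the caller did not pass its own.
            hc_kwargs.setdefault("device_requests", device_requests)
            return original(*hc_args, **hc_kwargs)

        # Shadow the client's method on this instance, for this call only.
        self.cli.create_host_config = create_host_config
        try:
            return run_method(self, *args, **kwargs)
        finally:
            # Put the client's own method back, even if the run failed.
            self.cli.create_host_config = original

    return wrapper


# Hypothetical usage in a DAG file (needs Airflow and docker-py installed):
#   import docker
#   from airflow.providers.docker.operators.docker import DockerOperator
#   DockerOperator._run_image_with_mounts = with_device_requests(
#       DockerOperator._run_image_with_mounts,
#       [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])],
#   )
```

The benefit of wrapping is that later provider fixes to `_run_image_with_mounts` are picked up automatically. The trade-off is that it still depends on a private method and on the host config being created through `self.cli.create_host_config`.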
