This is an automated email from the ASF dual-hosted git repository. fgerlits pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/nifi-minifi-cpp.git
commit 5c058ce703f7041842016a4e207604daf60565c8 Author: Martin Zink <[email protected]> AuthorDate: Wed Apr 17 10:48:11 2024 +0200 MINIFICPP-2349 FetchModbusTcp Signed-off-by: Ferenc Gerlits <[email protected]> This closes #1779 --- .gitignore | 1 + PROCESSORS.md | 30 +++ docker/requirements.txt | 2 +- docker/test/integration/cluster/ContainerStore.py | 9 + .../test/integration/cluster/DockerTestCluster.py | 5 + docker/test/integration/cluster/ImageStore.py | 5 + .../integration/cluster/checkers/ModbusChecker.py | 24 ++ .../integration/cluster/containers/DiagSlave.py | 36 +++ .../features/MiNiFi_integration_test_driver.py | 3 + .../integration/features/fetch_modbus_tcp.feature | 43 ++++ docker/test/integration/features/steps/steps.py | 23 ++ .../minifi/controllers/JsonRecordSetWriter.py | 24 ++ .../minifi/processors/FetchModbusTcp.py | 26 ++ .../validators/SingleJSONFileOutputValidator.py | 2 + .../integration/resources/diagslave/Dockerfile | 5 + extensions/standard-processors/CMakeLists.txt | 2 +- .../controllers/JsonRecordSetReader.cpp | 2 +- .../standard-processors/modbus/ByteConverters.h | 50 ++++ extensions/standard-processors/modbus/Error.h | 77 ++++++ .../standard-processors/modbus/FetchModbusTcp.cpp | 275 +++++++++++++++++++++ .../standard-processors/modbus/FetchModbusTcp.h | 143 +++++++++++ .../modbus/ReadModbusFunctions.cpp | 238 ++++++++++++++++++ .../modbus/ReadModbusFunctions.h | 172 +++++++++++++ .../standard-processors/processors/PutTCP.cpp | 228 +++-------------- extensions/standard-processors/processors/PutTCP.h | 27 +- .../standard-processors/tests/CMakeLists.txt | 4 +- .../tests/unit/JsonRecordTests.cpp | 26 +- .../tests/unit/modbus/ModbusTests.cpp | 259 +++++++++++++++++++ libminifi/include/core/Record.h | 11 +- libminifi/include/core/RecordField.h | 24 +- libminifi/include/utils/StringUtils.h | 17 ++ libminifi/include/utils/net/AsioSocketUtils.h | 8 +- libminifi/include/utils/net/ConnectionHandler.h | 173 +++++++++++++ 
.../include/utils/net/ConnectionHandlerBase.h | 41 +++ libminifi/test/libtest/unit/Catch.h | 2 +- libminifi/test/libtest/unit/TestRecord.h | 51 ++-- 36 files changed, 1819 insertions(+), 249 deletions(-) diff --git a/.gitignore b/.gitignore index c8b0fc41b..12a6c15a2 100644 --- a/.gitignore +++ b/.gitignore @@ -65,6 +65,7 @@ __pycache__/ /provenance_repository /logs msi/WixWin.wsi +docker/behavex_output .vs/** *.swp diff --git a/PROCESSORS.md b/PROCESSORS.md index 735e7e8a9..97ade1f37 100644 --- a/PROCESSORS.md +++ b/PROCESSORS.md @@ -42,6 +42,7 @@ limitations under the License. - [FetchAzureDataLakeStorage](#FetchAzureDataLakeStorage) - [FetchFile](#FetchFile) - [FetchGCSObject](#FetchGCSObject) +- [FetchModbusTcp](#FetchModbusTcp) - [FetchOPCProcessor](#FetchOPCProcessor) - [FetchS3Object](#FetchS3Object) - [FetchSFTP](#FetchSFTP) @@ -906,6 +907,35 @@ In the list below, the names of required properties appear in bold. Any other pr | gcs.error.domain | failure | The domain of the error occurred during operation. | +## FetchModbusTcp + +### Description + +Processor able to read data from industrial PLCs using Modbus TCP/IP + +### Properties + +In the list below, the names of required properties appear in bold. Any other properties (not in bold) are considered optional. The table also indicates any default values, and whether a property supports the NiFi Expression Language. 
+ +| Name | Default Value | Allowable Values | Description | +|--------------------------------|---------------|------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------| +| **Hostname** | | | The IP address or hostname of the destination.<br/>**Supports Expression Language: true** | +| **Port** | 502 | | The port or service on the destination.<br/>**Supports Expression Language: true** | +| **Unit Identifier** | 0 | | The unit identifier (slave address) of the Modbus device to query.<br/>**Supports Expression Language: true** | +| **Idle Connection Expiration** | 15 seconds | | The amount of time a connection should be held open without being used before closing the connection. A value of 0 seconds will disable this feature. | +| **Connection Per FlowFile** | false | true<br/>false | Specifies whether to open a separate connection for each FlowFile instead of reusing connections. | +| **Timeout** | 15 seconds | | The timeout for connecting to and communicating with the destination. | +| SSL Context Service | | | The Controller Service to use in order to obtain an SSL Context. If this property is set, messages will be sent over a secure connection. | +| **Record Set Writer** | | | Specifies the Controller Service to use for writing results to a FlowFile.
| + +### Relationships + +| Name | Description | +|---------|------------------------------| +| success | Successfully processed | +| failure | An error occurred processing | + + ## FetchOPCProcessor ### Description diff --git a/docker/requirements.txt b/docker/requirements.txt index e6b029eed..294f836ff 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -5,7 +5,7 @@ docker==5.0.0 kafka-python==2.0.2 confluent-kafka==1.7.0 PyYAML==6.0.1 -m2crypto==0.38.0 +m2crypto==0.41.0 watchdog==2.1.2 pyopenssl==23.0.0 azure-storage-blob==12.13.0 diff --git a/docker/test/integration/cluster/ContainerStore.py b/docker/test/integration/cluster/ContainerStore.py index 2f051bf57..f48646158 100644 --- a/docker/test/integration/cluster/ContainerStore.py +++ b/docker/test/integration/cluster/ContainerStore.py @@ -38,6 +38,7 @@ from .containers.MinifiC2ServerContainer import MinifiC2ServerContainer from .containers.GrafanaLokiContainer import GrafanaLokiContainer from .containers.GrafanaLokiContainer import GrafanaLokiOptions from .containers.ReverseProxyContainer import ReverseProxyContainer +from .containers.DiagSlave import DiagSlave from .FeatureContext import FeatureContext @@ -290,6 +291,14 @@ class ContainerStore: network=self.network, image_store=self.image_store, command=command)) + elif engine == "diag-slave-tcp": + return self.containers.setdefault(container_name, + DiagSlave(feature_context=feature_context, + name=container_name, + vols=self.vols, + network=self.network, + image_store=self.image_store, + command=command)) else: raise Exception('invalid flow engine: \'%s\'' % engine) diff --git a/docker/test/integration/cluster/DockerTestCluster.py b/docker/test/integration/cluster/DockerTestCluster.py index 581f439cf..430ec52f9 100644 --- a/docker/test/integration/cluster/DockerTestCluster.py +++ b/docker/test/integration/cluster/DockerTestCluster.py @@ -34,6 +34,7 @@ from .checkers.PostgresChecker import PostgresChecker from .checkers.PrometheusChecker 
import PrometheusChecker from .checkers.SplunkChecker import SplunkChecker from .checkers.GrafanaLokiChecker import GrafanaLokiChecker +from .checkers.ModbusChecker import ModbusChecker from utils import get_peak_memory_usage, get_minifi_pid, get_memory_usage, retry_check @@ -52,6 +53,7 @@ class DockerTestCluster: self.prometheus_checker = PrometheusChecker() self.grafana_loki_checker = GrafanaLokiChecker() self.minifi_controller_executor = MinifiControllerExecutor(self.container_communicator) + self.modbus_checker = ModbusChecker(self.container_communicator) def cleanup(self): self.container_store.cleanup() @@ -413,3 +415,6 @@ class DockerTestCluster: def wait_for_lines_on_grafana_loki(self, lines: List[str], timeout_seconds: int, ssl: bool, tenant_id: str): return self.grafana_loki_checker.wait_for_lines_on_grafana_loki(lines, timeout_seconds, ssl, tenant_id) + + def set_value_on_plc_with_modbus(self, container_name, modbus_cmd): + return self.modbus_checker.set_value_on_plc_with_modbus(container_name, modbus_cmd) diff --git a/docker/test/integration/cluster/ImageStore.py b/docker/test/integration/cluster/ImageStore.py index f82fcb792..03a5b0b94 100644 --- a/docker/test/integration/cluster/ImageStore.py +++ b/docker/test/integration/cluster/ImageStore.py @@ -66,6 +66,8 @@ class ImageStore: image = self.__build_splunk_image() elif container_engine == "reverse-proxy": image = self.__build_reverse_proxy_image() + elif container_engine == "diag-slave-tcp": + image = self.__build_diagslave_image() else: raise Exception("There is no associated image for " + container_engine) @@ -231,6 +233,9 @@ class ImageStore: def __build_reverse_proxy_image(self): return self.__build_image_by_path(self.test_dir + "/resources/reverse-proxy", 'reverse-proxy') + def __build_diagslave_image(self): + return self.__build_image_by_path(self.test_dir + "/resources/diagslave", 'diag-slave-tcp') + def __build_image(self, dockerfile, context_files=[]): conf_dockerfile_buffer = BytesIO() 
docker_context_buffer = BytesIO() diff --git a/docker/test/integration/cluster/checkers/ModbusChecker.py b/docker/test/integration/cluster/checkers/ModbusChecker.py new file mode 100644 index 000000000..9ef26a6fe --- /dev/null +++ b/docker/test/integration/cluster/checkers/ModbusChecker.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class ModbusChecker: + def __init__(self, container_communicator): + self.container_communicator = container_communicator + + def set_value_on_plc_with_modbus(self, container_name, modbus_cmd): + print(modbus_cmd) + (code, output) = self.container_communicator.execute_command(container_name, ["modbus", "localhost", modbus_cmd]) + print(output) + return code == 0 diff --git a/docker/test/integration/cluster/containers/DiagSlave.py b/docker/test/integration/cluster/containers/DiagSlave.py new file mode 100644 index 000000000..25c70f7b2 --- /dev/null +++ b/docker/test/integration/cluster/containers/DiagSlave.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from .Container import Container + + +class DiagSlave(Container): + def __init__(self, feature_context, name, vols, network, image_store, command=None): + super().__init__(feature_context, name, 'diag-slave-tcp', vols, network, image_store, command) + + def get_startup_finished_log_entry(self): + return "Server started up successfully." + + def deploy(self): + if not self.set_deployed(): + return + + logging.info('Creating and running a DiagSlave docker container...') + self.client.containers.run( + self.image_store.get_image(self.get_engine()), + detach=True, + name=self.name, + network=self.network.name) + logging.info('Added container \'%s\'', self.name) diff --git a/docker/test/integration/features/MiNiFi_integration_test_driver.py b/docker/test/integration/features/MiNiFi_integration_test_driver.py index 34acc0096..6ad052294 100644 --- a/docker/test/integration/features/MiNiFi_integration_test_driver.py +++ b/docker/test/integration/features/MiNiFi_integration_test_driver.py @@ -479,3 +479,6 @@ class MiNiFi_integration_test: def check_lines_on_grafana_loki(self, lines: List[str], timeout_seconds: int, ssl: bool, tenant_id=None): assert self.cluster.wait_for_lines_on_grafana_loki(lines, timeout_seconds, ssl, tenant_id) or self.cluster.log_app_output() + + def set_value_on_plc_with_modbus(self, container_name, modbus_cmd): + assert 
self.cluster.set_value_on_plc_with_modbus(container_name, modbus_cmd) diff --git a/docker/test/integration/features/fetch_modbus_tcp.feature b/docker/test/integration/features/fetch_modbus_tcp.feature new file mode 100644 index 000000000..6c1eab84e --- /dev/null +++ b/docker/test/integration/features/fetch_modbus_tcp.feature @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +@MODBUS +Feature: Minifi C++ can act as a modbus tcp master + + Background: + Given the content of "/tmp/output" is monitored + + Scenario: MiNiFi can fetch data from a modbus slave + Given a FetchModbusTcp processor + And a JsonRecordSetWriter controller service is set up for FetchModbusTcp + And a PutFile processor with the "Directory" property set to "/tmp/output" + And the "Unit identifier" property of the FetchModbusTcp processor is set to "255" + And there is an accessible PLC with modbus enabled + And PLC register has been set with h@52=123 command + And PLC register has been set with h@5678/f=1.75 command + And PLC register has been set with h@4444=77 command + And PLC register has been set with h@4445=105 command + And PLC register has been set with h@4446=78 command + And PLC register has been set with h@4447=105 command + And PLC register has been set with h@4448=70 command + And PLC register has been set with h@4449=105 command + + And the "success" relationship of the FetchModbusTcp processor is connected to the PutFile + And the "foo" property of the FetchModbusTcp processor is set to "holding-register:52" + And the "bar" property of the FetchModbusTcp processor is set to "405678:REAL" + And the "baz" property of the FetchModbusTcp processor is set to "4x4444:CHAR[6]" + + When both instances start up + Then a flowfile with the JSON content "{"foo":123,"bar":1.75,"baz":["M", "i", "N", "i", "F", "i"]}" is placed in the monitored directory in less than 10 seconds diff --git a/docker/test/integration/features/steps/steps.py b/docker/test/integration/features/steps/steps.py index b26cd1c61..75cf19a34 100644 --- a/docker/test/integration/features/steps/steps.py +++ b/docker/test/integration/features/steps/steps.py @@ -24,6 +24,7 @@ from minifi.controllers.GCPCredentialsControllerService import GCPCredentialsCon from minifi.controllers.ElasticsearchCredentialsService import ElasticsearchCredentialsService from minifi.controllers.ODBCService import 
ODBCService from minifi.controllers.KubernetesControllerService import KubernetesControllerService +from minifi.controllers.JsonRecordSetWriter import JsonRecordSetWriter from behave import given, then, when from behave.model_describe import ModelDescriptor @@ -395,6 +396,16 @@ def step_impl(context, processor_name): processor.set_property('SSL Context Service', ssl_context_service.name) +# RecordSetWriters +@given("a JsonRecordSetWriter controller service is set up for {processor_name}") +def step_impl(context, processor_name): + json_record_set_writer = JsonRecordSetWriter() + + processor = context.test.get_node_by_name(processor_name) + processor.controller_services.append(json_record_set_writer) + processor.set_property('Record Set Writer', json_record_set_writer.name) + + # Kubernetes def __set_up_the_kubernetes_controller_service(context, processor_name, service_property_name, properties): kubernetes_controller_service = KubernetesControllerService("Kubernetes Controller Service", properties) @@ -1305,3 +1316,15 @@ def step_impl(context, parameter_context_name, parameter_name, parameter_value): def step_impl(context, parameter_context_name): container = context.test.acquire_container(context=context, name='minifi-cpp-flow', engine='minifi-cpp') container.set_parameter_context_name(parameter_context_name) + + +# Modbus +@given(u'there is an accessible PLC with modbus enabled') +def step_impl(context): + context.test.acquire_container(context=context, name="diag-slave-tcp", engine="diag-slave-tcp") + context.test.start('diag-slave-tcp') + + +@given(u'PLC register has been set with {modbus_cmd} command') +def step_impl(context, modbus_cmd): + context.test.set_value_on_plc_with_modbus(context.test.get_container_name_with_postfix('diag-slave-tcp'), modbus_cmd) diff --git a/docker/test/integration/minifi/controllers/JsonRecordSetWriter.py b/docker/test/integration/minifi/controllers/JsonRecordSetWriter.py new file mode 100644 index 000000000..cfb15cfd1 --- 
/dev/null +++ b/docker/test/integration/minifi/controllers/JsonRecordSetWriter.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ..core.ControllerService import ControllerService + + +class JsonRecordSetWriter(ControllerService): + def __init__(self, name=None, cert=None, key=None, ca_cert=None, passphrase=None, use_system_cert_store=None): + super(JsonRecordSetWriter, self).__init__(name=name) + self.service_class = 'JsonRecordSetWriter' + self.properties['Output Grouping'] = 'OneLinePerObject' diff --git a/docker/test/integration/minifi/processors/FetchModbusTcp.py b/docker/test/integration/minifi/processors/FetchModbusTcp.py new file mode 100644 index 000000000..0141817cf --- /dev/null +++ b/docker/test/integration/minifi/processors/FetchModbusTcp.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ..core.Processor import Processor + + +class FetchModbusTcp(Processor): + def __init__(self, context): + super(FetchModbusTcp, self).__init__( + context=context, + clazz='FetchModbusTcp', + properties={ + 'Hostname': f'diag-slave-tcp-{context.feature_id}', + }, + auto_terminate=["success", "failure"]) diff --git a/docker/test/integration/minifi/validators/SingleJSONFileOutputValidator.py b/docker/test/integration/minifi/validators/SingleJSONFileOutputValidator.py index 438a0d888..7c8be2dd0 100644 --- a/docker/test/integration/minifi/validators/SingleJSONFileOutputValidator.py +++ b/docker/test/integration/minifi/validators/SingleJSONFileOutputValidator.py @@ -40,6 +40,8 @@ class SingleJSONFileOutputValidator(FileOutputValidator): continue with open(full_path, 'r') as out_file: file_json_content = json.loads(out_file.read()) + if file_json_content != expected_json_content: + print(f"JSON doesnt match actual: {file_json_content}, expected: {expected_json_content}") return file_json_content == expected_json_content return False diff --git a/docker/test/integration/resources/diagslave/Dockerfile b/docker/test/integration/resources/diagslave/Dockerfile new file mode 100644 index 000000000..4b7d29ba2 --- /dev/null +++ b/docker/test/integration/resources/diagslave/Dockerfile @@ -0,0 +1,5 @@ +FROM panterdsd/diagslave:latest +RUN pip install modbus-cli + +ENV PROTOCOL=tcp + diff --git a/extensions/standard-processors/CMakeLists.txt b/extensions/standard-processors/CMakeLists.txt index 630fb3cab..f62003e75 100644 --- a/extensions/standard-processors/CMakeLists.txt 
+++ b/extensions/standard-processors/CMakeLists.txt @@ -20,7 +20,7 @@ include(${CMAKE_SOURCE_DIR}/extensions/ExtensionHeader.txt) -file(GLOB SOURCES "processors/*.cpp" "controllers/*.cpp" "utils/*.cpp") +file(GLOB SOURCES "processors/*.cpp" "controllers/*.cpp" "utils/*.cpp" "modbus/*.cpp") add_minifi_library(minifi-standard-processors SHARED ${SOURCES}) target_include_directories(minifi-standard-processors PUBLIC "${CMAKE_SOURCE_DIR}/extensions/standard-processors") diff --git a/extensions/standard-processors/controllers/JsonRecordSetReader.cpp b/extensions/standard-processors/controllers/JsonRecordSetReader.cpp index 1214bd1eb..19435a357 100644 --- a/extensions/standard-processors/controllers/JsonRecordSetReader.cpp +++ b/extensions/standard-processors/controllers/JsonRecordSetReader.cpp @@ -38,7 +38,7 @@ nonstd::expected<core::RecordField, std::error_code> parse(const rapidjson::Valu return core::RecordField{json_value.GetInt64()}; } if (json_value.IsString()) { - return core::RecordField{json_value.GetString()}; + return core::RecordField{std::string{json_value.GetString(), json_value.GetStringLength()}}; } if (json_value.IsArray()) { core::RecordArray record_array; diff --git a/extensions/standard-processors/modbus/ByteConverters.h b/extensions/standard-processors/modbus/ByteConverters.h new file mode 100644 index 000000000..cc5d38e17 --- /dev/null +++ b/extensions/standard-processors/modbus/ByteConverters.h @@ -0,0 +1,50 @@ +/** +* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include <array> +#include <cstdint> +#include <span> +#include <bit> + +namespace org::apache::nifi::minifi::modbus { + +template<typename T, std::endian to_endianness = std::endian::big> +constexpr std::array<std::byte, std::max(sizeof(T), sizeof(uint16_t))> toBytes(T value) { + std::array<std::byte, std::max(sizeof(T), sizeof(uint16_t))> buffer{}; + + std::copy_n(reinterpret_cast<std::byte*>(&value), sizeof(T), buffer.begin()); + + if constexpr (std::endian::native != to_endianness) { + std::reverse(buffer.begin(), buffer.end()); + } + + return buffer; +} + +template<typename T, std::endian from_endianness = std::endian::big> +constexpr T fromBytes(std::array<std::byte, std::max(sizeof(T), sizeof(uint16_t))> bytes) { + if constexpr (std::endian::native != from_endianness) { + std::reverse(bytes.begin(), bytes.end()); + } + + T result; + std::memcpy(&result, bytes.data(), sizeof(T)); + return result; +} +} // namespace org::apache::nifi::minifi::modbus diff --git a/extensions/standard-processors/modbus/Error.h b/extensions/standard-processors/modbus/Error.h new file mode 100644 index 000000000..6ba3714b6 --- /dev/null +++ b/extensions/standard-processors/modbus/Error.h @@ -0,0 +1,77 @@ +/** + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <string> +#include <system_error> +#include "magic_enum.hpp" + +namespace org::apache::nifi::minifi::modbus { + +// Based from https://modbus.org/docs/Modbus_Application_Protocol_V1_1b3.pdf MODBUS Exception Codes +enum class ModbusExceptionCode : std::underlying_type_t<std::byte> { + IllegalFunction = 0x01, + IllegalDataAddress = 0x02, + IllegalDataValue = 0x03, + SlaveDeviceFailure = 0x04, + Acknowledge = 0x05, + SlaveDeviceBusy = 0x06, + NegativeAcknowledge = 0x07, + MemoryParityError = 0x08, + GatewayPathUnavailable = 0x0a, + GatewayTargetDeviceFailedToRespond = 0x0b, + InvalidResponse, + MessageTooLarge, + InvalidTransactionId, + IllegalProtocol, + InvalidSlaveId, + MessageTooShort, + UnexpectedResponseFunctionCode, + UnexpectedResponsePDUSize +}; + + +struct ModbusErrorCategory final : std::error_category { + [[nodiscard]] const char* name() const noexcept override { + return "modbus error"; + } + + [[nodiscard]] std::string message(int ev) const override { + const auto modbus_exception_code = static_cast<ModbusExceptionCode>(ev); + auto modbus_exception_code_str = std::string{magic_enum::enum_name<ModbusExceptionCode>(modbus_exception_code)}; + if (modbus_exception_code_str.empty()) { + return "UNKNOWN ERROR"; + } + return modbus_exception_code_str; + } +}; + +inline const ModbusErrorCategory& modbus_category() noexcept { + static ModbusErrorCategory 
category; + return category; +}; + +inline std::error_code make_error_code(ModbusExceptionCode c) { + return {static_cast<int>(c), modbus_category()}; +} + +} // namespace org::apache::nifi::minifi::modbus + +template <> +struct std::is_error_code_enum<org::apache::nifi::minifi::modbus::ModbusExceptionCode> : std::true_type {}; diff --git a/extensions/standard-processors/modbus/FetchModbusTcp.cpp b/extensions/standard-processors/modbus/FetchModbusTcp.cpp new file mode 100644 index 000000000..1f8fc0ff7 --- /dev/null +++ b/extensions/standard-processors/modbus/FetchModbusTcp.cpp @@ -0,0 +1,275 @@ +/** +* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "FetchModbusTcp.h" + +#include <utils/net/ConnectionHandler.h> + +#include "core/Resource.h" + +#include "core/ProcessSession.h" +#include "modbus/Error.h" +#include "modbus/ReadModbusFunctions.h" +#include "utils/net/AsioCoro.h" +#include "utils/net/AsioSocketUtils.h" + +using namespace std::literals::chrono_literals; + +namespace org::apache::nifi::minifi::modbus { + + +void FetchModbusTcp::onSchedule(core::ProcessContext& context, core::ProcessSessionFactory&) { + const auto record_set_writer_name = context.getProperty(RecordSetWriter); + record_set_writer_ = std::dynamic_pointer_cast<core::RecordSetWriter>(context.getControllerService(record_set_writer_name.value_or(""))); + if (!record_set_writer_) { + throw Exception{ExceptionType::PROCESS_SCHEDULE_EXCEPTION, "Invalid or missing RecordSetWriter"}; + } + + // if the required properties are missing or empty even before evaluating the EL expression, then we can throw in onSchedule, before we waste any flow files + if (context.getProperty(Hostname).value_or(std::string{}).empty()) { + throw Exception{ExceptionType::PROCESS_SCHEDULE_EXCEPTION, "missing hostname"}; + } + if (context.getProperty(Port).value_or(std::string{}).empty()) { + throw Exception{ExceptionType::PROCESS_SCHEDULE_EXCEPTION, "missing port"}; + } + if (const auto idle_connection_expiration = context.getProperty<core::TimePeriodValue>(IdleConnectionExpiration); idle_connection_expiration && idle_connection_expiration->getMilliseconds() > 0ms) { + idle_connection_expiration_ = idle_connection_expiration->getMilliseconds(); + } else { + idle_connection_expiration_.reset(); + } + + if (const auto timeout = context.getProperty<core::TimePeriodValue>(Timeout); timeout && timeout->getMilliseconds() > 0ms) { + timeout_duration_ = timeout->getMilliseconds(); + } else { + timeout_duration_ = 15s; + } + + if (context.getProperty<bool>(ConnectionPerFlowFile).value_or(false)) { + connections_.reset(); + } else { + connections_.emplace(); + } 
+ + ssl_context_.reset(); + if (const auto controller_service_name = context.getProperty(SSLContextService); controller_service_name && !IsNullOrEmpty(*controller_service_name)) { + if (auto controller_service = context.getControllerService(*controller_service_name)) { + if (const auto ssl_context_service = std::dynamic_pointer_cast<minifi::controllers::SSLContextService>(controller_service)) { + ssl_context_ = utils::net::getSslContext(*ssl_context_service); + } else { + throw Exception(PROCESS_SCHEDULE_EXCEPTION, *controller_service_name + " is not an SSL Context Service"); + } + } else { + throw Exception(PROCESS_SCHEDULE_EXCEPTION, "Invalid controller service: " + *controller_service_name); + } + } + + readDynamicPropertyKeys(context); +} + +void FetchModbusTcp::onTrigger(core::ProcessContext& context, core::ProcessSession& session) { + const auto flow_file = getOrCreateFlowFile(session); + if (!flow_file) { + logger_->log_error("No flowfile to work on"); + return; + } + + removeExpiredConnections(); + + auto hostname = context.getProperty(Hostname, flow_file.get()).value_or(std::string{}); + auto port = context.getProperty(Port, flow_file.get()).value_or(std::string{}); + + if (hostname.empty() || port.empty()) { + logger_->log_error("[{}] invalid target endpoint: hostname: {}, port: {}", flow_file->getUUIDStr(), + hostname.empty() ? "(empty)" : hostname.c_str(), + port.empty() ? 
"(empty)" : port.c_str()); + session.transfer(flow_file, Failure); + return; + } + + auto connection_id = utils::net::ConnectionId(std::move(hostname), std::move(port)); + std::shared_ptr<utils::net::ConnectionHandlerBase> handler; + if (!connections_ || !connections_->contains(connection_id)) { + if (ssl_context_) { + handler = std::make_shared<utils::net::ConnectionHandler<utils::net::SslSocket>>(connection_id, timeout_duration_, logger_, max_size_of_socket_send_buffer_, &*ssl_context_); + } else { + handler = std::make_shared<utils::net::ConnectionHandler<utils::net::TcpSocket>>(connection_id, timeout_duration_, logger_, max_size_of_socket_send_buffer_, nullptr); + } + if (connections_) { + (*connections_)[connection_id] = handler; + } + } else { + handler = (*connections_)[connection_id]; + } + + gsl_Expects(handler); + + processFlowFile(handler, context, session, flow_file); +} + +void FetchModbusTcp::initialize() { + setSupportedProperties(Properties); + setSupportedRelationships(Relationships); +} + +void FetchModbusTcp::readDynamicPropertyKeys(const core::ProcessContext& context) { + dynamic_property_keys_.clear(); + const std::vector<std::string> dynamic_prop_keys = context.getDynamicPropertyKeys(); + for (const auto& key : dynamic_prop_keys) { + dynamic_property_keys_.emplace_back(core::PropertyDefinitionBuilder<>::createProperty(key).withDescription("auto generated").supportsExpressionLanguage(true).build()); + } +} + +std::shared_ptr<core::FlowFile> FetchModbusTcp::getOrCreateFlowFile(core::ProcessSession& session) const { + if (hasIncomingConnections()) { + return session.get(); + } + return session.create(); +} + +std::unordered_map<std::string, std::unique_ptr<ReadModbusFunction>> FetchModbusTcp::getAddressMap(core::ProcessContext& context, const core::FlowFile& flow_file) { + std::unordered_map<std::string, std::unique_ptr<ReadModbusFunction>> address_map{}; + const auto unit_id_str = context.getProperty(UnitIdentifier, &flow_file).value_or("1"); + 
const uint8_t unit_id = utils::string::parseNumber<uint8_t>(unit_id_str) | utils::valueOrElse([this](const utils::string::ParseError&) { + logger_->log_error("Couldnt parse UnitIdentifier"); + return uint8_t{1}; + }); + for (const auto& dynamic_property : dynamic_property_keys_) { + if (std::string dynamic_property_value{}; context.getDynamicProperty(dynamic_property, dynamic_property_value, &flow_file)) { + if (auto modbus_func = ReadModbusFunction::parse(++transaction_id_, unit_id, dynamic_property_value); modbus_func) { + address_map.emplace(dynamic_property.getName(), std::move(modbus_func)); + } + } + } + return address_map; +} + +void FetchModbusTcp::removeExpiredConnections() { + if (connections_) { + std::erase_if(*connections_, [this](auto& item) -> bool { + const auto& connection_handler = item.second; + return (!connection_handler || (idle_connection_expiration_ && !connection_handler->hasBeenUsedIn(*idle_connection_expiration_))); + }); + } +} + +void FetchModbusTcp::processFlowFile(const std::shared_ptr<utils::net::ConnectionHandlerBase>& connection_handler, + core::ProcessContext& context, + core::ProcessSession& session, + const std::shared_ptr<core::FlowFile>& flow_file) { + std::unordered_map<std::string, std::string> result_map{}; + const auto address_map = getAddressMap(context, *flow_file); + if (address_map.empty()) { + logger_->log_warn("There are no registers to query"); + session.transfer(flow_file, Failure); + return; + } + + if (auto result = readModbus(connection_handler, address_map); !result) { + connection_handler->reset(); + logger_->log_error("{}", result.error().message()); + session.transfer(flow_file, Failure); + } else { + core::RecordSet record_set; + record_set.push_back(std::move(*result)); + record_set_writer_->write(record_set, flow_file, session); + session.transfer(flow_file, Success); + } +} + +nonstd::expected<core::Record, std::error_code> FetchModbusTcp::readModbus( + const 
std::shared_ptr<utils::net::ConnectionHandlerBase>& connection_handler, + const std::unordered_map<std::string, std::unique_ptr<ReadModbusFunction>>& address_map) { + nonstd::expected<core::Record, std::error_code> result; + io_context_.restart(); + asio::co_spawn(io_context_, + sendRequestsAndReadResponses(*connection_handler, address_map), + [&result](const std::exception_ptr& exception_ptr, auto res) { + if (exception_ptr) { + result = nonstd::make_unexpected(ModbusExceptionCode::InvalidResponse); + } else { + result = std::move(res); + } + }); + io_context_.run(); + return result; +} + +auto FetchModbusTcp::sendRequestsAndReadResponses(utils::net::ConnectionHandlerBase& connection_handler, + const std::unordered_map<std::string, std::unique_ptr<ReadModbusFunction>>& address_map) -> asio::awaitable<nonstd::expected<core::Record, std::error_code>> { + core::Record result; + for (const auto& [variable, read_modbus_fn] : address_map) { + gsl_Expects(read_modbus_fn); + auto response = co_await sendRequestAndReadResponse(connection_handler, *read_modbus_fn); + if (!response) { + co_return nonstd::make_unexpected(response.error()); + } + result.emplace(variable, std::move(*response)); + } + co_return result; +} + + +auto FetchModbusTcp::sendRequestAndReadResponse(utils::net::ConnectionHandlerBase& connection_handler, + const ReadModbusFunction& read_modbus_function) -> asio::awaitable<nonstd::expected<core::RecordField, std::error_code>> { + std::string result; + if (auto connection_error = co_await connection_handler.setupUsableSocket(io_context_)) { // NOLINT (clang tidy doesnt like coroutines) + co_return nonstd::make_unexpected(connection_error); + } + + if (auto [write_error, bytes_written] = co_await connection_handler.write(asio::buffer(read_modbus_function.requestBytes())); write_error) { + co_return nonstd::make_unexpected(write_error); + } + + std::array<std::byte, 7> apu_buffer{}; + asio::mutable_buffer response_apu(apu_buffer.data(), 7); + if (auto 
[read_error, bytes_read] = co_await connection_handler.read(response_apu); read_error) { + co_return nonstd::make_unexpected(read_error); + } + + const auto received_transaction_id = fromBytes<uint16_t>({apu_buffer[0], apu_buffer[1]}); + const auto received_protocol = fromBytes<uint16_t>({apu_buffer[2], apu_buffer[3]}); + const auto received_length = fromBytes<uint16_t>({apu_buffer[4], apu_buffer[5]}); + const auto unit_id = static_cast<uint8_t>(apu_buffer[6]); + + if (received_transaction_id != read_modbus_function.getTransactionId()) { + co_return nonstd::make_unexpected(ModbusExceptionCode::InvalidTransactionId); + } + if (received_protocol != 0) { + co_return nonstd::make_unexpected(ModbusExceptionCode::IllegalProtocol); + } + if (unit_id != read_modbus_function.getUnitId()) { + co_return nonstd::make_unexpected(ModbusExceptionCode::InvalidSlaveId); + } + if (received_length + 6 > 260 || received_length <= 1) { + co_return nonstd::make_unexpected(ModbusExceptionCode::InvalidResponse); + } + + std::array<std::byte, 260-7> pdu_buffer{}; + asio::mutable_buffer response_pdu(pdu_buffer.data(), received_length-1); + auto [read_error, bytes_read] = co_await connection_handler.read(response_pdu); + if (read_error) { + co_return nonstd::make_unexpected(read_error); + } + + const auto pdu_span = std::span<std::byte>(pdu_buffer.data(), received_length-1); + co_return read_modbus_function.responseToRecordField(pdu_span); +} + +REGISTER_RESOURCE(FetchModbusTcp, Processor); + + +} // namespace org::apache::nifi::minifi::modbus diff --git a/extensions/standard-processors/modbus/FetchModbusTcp.h b/extensions/standard-processors/modbus/FetchModbusTcp.h new file mode 100644 index 000000000..b0e0a4a8d --- /dev/null +++ b/extensions/standard-processors/modbus/FetchModbusTcp.h @@ -0,0 +1,143 @@ +/** +* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "controllers/SSLContextService.h"
+#include "controllers/RecordSetWriter.h"
+#include "core/Processor.h"
+#include "core/PropertyDefinitionBuilder.h"
+#include "logging/LoggerFactory.h"
+#include "utils/net/AsioCoro.h"
+#include "utils/net/AsioSocketUtils.h"
+#include "utils/net/ConnectionHandlerBase.h"
+
+namespace org::apache::nifi::minifi::modbus {
+
+class ReadModbusFunction;
+
+class FetchModbusTcp final : public core::Processor {
+ public:
+  explicit FetchModbusTcp(const std::string_view name, const utils::Identifier& uuid = {})
+      : Processor(name, uuid) {
+  }
+
+  EXTENSIONAPI static constexpr auto Description = "Processor able to read data from industrial PLCs using Modbus TCP/IP";
+
+  EXTENSIONAPI static constexpr auto Hostname = core::PropertyDefinitionBuilder<>::createProperty("Hostname")
+      .withDescription("The ip address or hostname of the destination.")
+      .isRequired(true)
+      .supportsExpressionLanguage(true)
+      .build();
+  EXTENSIONAPI static constexpr auto Port = core::PropertyDefinitionBuilder<>::createProperty("Port")
+      .withDescription("The port or service on the destination.")
+      .withDefaultValue("502")
+      .isRequired(true)
+      .supportsExpressionLanguage(true)
+      .build();
+  EXTENSIONAPI static constexpr auto UnitIdentifier = core::PropertyDefinitionBuilder<>::createProperty("Unit Identifier")
+      .withDescription("The unit identifier (slave ID) of the Modbus device to query.")
+      .isRequired(true)
+      .withDefaultValue("0")
+      .supportsExpressionLanguage(true)
+      .build();
+  EXTENSIONAPI static constexpr auto IdleConnectionExpiration = core::PropertyDefinitionBuilder<>::createProperty("Idle Connection Expiration")
+      .withDescription("The amount of time a connection should be held open without being used before closing the connection. A value of 0 seconds will disable this feature.")
+      .withPropertyType(core::StandardPropertyTypes::TIME_PERIOD_TYPE)
+      .withDefaultValue("15 seconds")
+      .isRequired(true)
+      .supportsExpressionLanguage(false)
+      .build();
+  EXTENSIONAPI static constexpr auto ConnectionPerFlowFile = core::PropertyDefinitionBuilder<>::createProperty("Connection Per FlowFile")
+      .withDescription("Specifies whether to send each FlowFile's content on an individual connection.")
+      .withPropertyType(core::StandardPropertyTypes::BOOLEAN_TYPE)
+      .withDefaultValue("false")
+      .isRequired(true)
+      .supportsExpressionLanguage(false)
+      .build();
+  EXTENSIONAPI static constexpr auto Timeout = core::PropertyDefinitionBuilder<>::createProperty("Timeout")
+      .withDescription("The timeout for connecting to and communicating with the destination.")
+      .withPropertyType(core::StandardPropertyTypes::TIME_PERIOD_TYPE)
+      .withDefaultValue("15 seconds")
+      .isRequired(true)
+      .supportsExpressionLanguage(false)
+      .build();
+  EXTENSIONAPI static constexpr auto SSLContextService = core::PropertyDefinitionBuilder<>::createProperty("SSL Context Service")
+      .withDescription("The Controller Service to use in order to obtain an SSL Context. If this property is set, messages will be sent over a secure connection.")
+      .isRequired(false)
+      .withAllowedTypes<minifi::controllers::SSLContextService>()
+      .build();
+  EXTENSIONAPI static constexpr auto RecordSetWriter = core::PropertyDefinitionBuilder<>::createProperty("Record Set Writer")
+      .withDescription("Specifies the Controller Service to use for writing results to a FlowFile. ")
+      .isRequired(true)
+      .withAllowedTypes<core::RecordSetWriter>()
+      .build();
+
+  EXTENSIONAPI static constexpr auto Properties = std::array<core::PropertyReference, 8>{
+      Hostname,
+      Port,
+      UnitIdentifier,
+      IdleConnectionExpiration,
+      ConnectionPerFlowFile,
+      Timeout,
+      SSLContextService,
+      RecordSetWriter
+  };
+
+  EXTENSIONAPI static constexpr auto Success = core::RelationshipDefinition{"success", "Successfully processed"};
+  EXTENSIONAPI static constexpr auto Failure = core::RelationshipDefinition{"failure", "An error occurred processing"};
+  EXTENSIONAPI static constexpr auto Relationships = std::array{Success, Failure};
+
+  EXTENSIONAPI static constexpr bool SupportsDynamicProperties = true;  // each dynamic property is one address to read
+  EXTENSIONAPI static constexpr bool SupportsDynamicRelationships = false;
+  EXTENSIONAPI static constexpr auto InputRequirement = core::annotation::Input::INPUT_ALLOWED;
+  EXTENSIONAPI static constexpr bool IsSingleThreaded = true;
+
+  ADD_COMMON_VIRTUAL_FUNCTIONS_FOR_PROCESSORS
+
+  void onSchedule(core::ProcessContext& context, core::ProcessSessionFactory& session_factory) override;
+  void onTrigger(core::ProcessContext& context, core::ProcessSession& session) override;
+  void initialize() override;
+
+ private:
+  void readDynamicPropertyKeys(const core::ProcessContext& context);
+  void processFlowFile(const std::shared_ptr<utils::net::ConnectionHandlerBase>& connection_handler,
+      core::ProcessContext& context,
+      core::ProcessSession& session,
+      const std::shared_ptr<core::FlowFile>& flow_file);
+
+  nonstd::expected<core::Record, std::error_code> readModbus(const std::shared_ptr<utils::net::ConnectionHandlerBase>& connection_handler,
+      const std::unordered_map<std::string, std::unique_ptr<ReadModbusFunction>>& address_map);
+  asio::awaitable<nonstd::expected<core::Record, std::error_code>> sendRequestsAndReadResponses(utils::net::ConnectionHandlerBase& connection_handler,
+      const std::unordered_map<std::string, std::unique_ptr<ReadModbusFunction>>& address_map);
+  asio::awaitable<nonstd::expected<core::RecordField, std::error_code>> sendRequestAndReadResponse(utils::net::ConnectionHandlerBase& connection_handler,
+      const ReadModbusFunction& read_modbus_function);
+  std::unordered_map<std::string, std::unique_ptr<ReadModbusFunction>> getAddressMap(core::ProcessContext& context, const core::FlowFile& flow_file);
+  std::shared_ptr<core::FlowFile> getOrCreateFlowFile(core::ProcessSession& session) const;
+  void removeExpiredConnections();
+
+  std::vector<core::Property> dynamic_property_keys_;
+  asio::io_context io_context_;  // driven synchronously from readModbus via restart() + run()
+  std::optional<std::unordered_map<utils::net::ConnectionId, std::shared_ptr<utils::net::ConnectionHandlerBase>>> connections_;  // nullopt: handlers are not cached
+  std::optional<std::chrono::milliseconds> idle_connection_expiration_;
+  std::atomic<uint16_t> transaction_id_ = 0;  // pre-incremented for every parsed address; wraps around after 65535
+  std::optional<size_t> max_size_of_socket_send_buffer_;
+  std::chrono::milliseconds timeout_duration_ = std::chrono::seconds(15);
+  std::optional<asio::ssl::context> ssl_context_;
+  std::shared_ptr<core::logging::Logger> logger_ = core::logging::LoggerFactory<FetchModbusTcp>::getLogger(uuid_);
+  std::shared_ptr<core::RecordSetWriter> record_set_writer_;
+};
+}  // namespace org::apache::nifi::minifi::modbus
diff --git a/extensions/standard-processors/modbus/ReadModbusFunctions.cpp b/extensions/standard-processors/modbus/ReadModbusFunctions.cpp
new file mode 100644
index 000000000..0eaf49e47
--- /dev/null
+++ b/extensions/standard-processors/modbus/ReadModbusFunctions.cpp
@@ -0,0 +1,238 @@
+/**
+*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ReadModbusFunctions.h"
+
+#include <range/v3/view/drop.hpp>
+
+namespace org::apache::nifi::minifi::modbus {
+std::vector<std::byte> ReadModbusFunction::requestBytes() const {
+  constexpr std::array modbus_service_protocol_identifier = {std::byte{0}, std::byte{0}};
+  const auto pdu = rawPdu();
+  const auto length = gsl::narrow<uint16_t>(pdu.size() + 1);  // the MBAP length field counts the unit id plus the PDU
+
+  std::vector<std::byte> request;
+  ranges::copy(toBytes(transaction_id_), std::back_inserter(request));
+  ranges::copy(modbus_service_protocol_identifier, std::back_inserter(request));
+  ranges::copy(toBytes(length), std::back_inserter(request));
+  request.push_back(std::byte{unit_id_});
+  ranges::copy(pdu, std::back_inserter(request));
+  return request;
+}
+
+[[nodiscard]] auto ReadModbusFunction::getRespBytes(std::span<const std::byte> resp_pdu) const -> nonstd::expected<std::span<const std::byte>, std::error_code> {
+  if (resp_pdu.size() < 2) {
+    return nonstd::make_unexpected(ModbusExceptionCode::MessageTooShort);
+  }
+
+  if (const auto resp_function_code = resp_pdu.front(); resp_function_code != getFunctionCode()) {
+    return nonstd::make_unexpected(ModbusExceptionCode::UnexpectedResponseFunctionCode);
+  }
+
+  const auto resp_byte_count = static_cast<uint8_t>(resp_pdu[1]);
+  constexpr size_t function_code_length = 1;
+  constexpr size_t byte_count_length = 1;
+  const size_t expected_resp_pdu_size = function_code_length + byte_count_length + resp_byte_count;  // size_t: cannot wrap for large byte counts
+  if (resp_pdu.size() != expected_resp_pdu_size) {
+    return nonstd::make_unexpected(ModbusExceptionCode::UnexpectedResponsePDUSize);
+  }
+
+  if (resp_byte_count != expectedByteCount()) {
+    return nonstd::make_unexpected(ModbusExceptionCode::InvalidResponse);
+  }
+
+  return resp_pdu.subspan(function_code_length + byte_count_length);  // the payload after function code and byte count
+}
+
+[[nodiscard]] std::array<std::byte, 5> ReadCoilStatus::rawPdu() const {
+  std::array<std::byte, 5> result{};
+  result[0] = getFunctionCode();
+
+  ranges::copy(toBytes(starting_address_), (result | ranges::views::drop(1) | ranges::views::take(2)).begin());
+  ranges::copy(toBytes(number_of_points_), (result | ranges::views::drop(3) | ranges::views::take(2)).begin());
+  return result;
+}
+
+[[nodiscard]] nonstd::expected<core::RecordField, std::error_code> ReadCoilStatus::responseToRecordField(const std::span<const std::byte> resp_pdu) const {
+  const auto resp_bytes = getRespBytes(resp_pdu);
+  if (!resp_bytes)
+    return nonstd::make_unexpected(resp_bytes.error());
+
+  // Coils are packed eight per byte, lowest bit first; padding bits in the last byte are ignored.
+  std::vector<bool> coils{};
+  for (const auto& resp_byte : *resp_bytes) {
+    for (uint8_t i = 0; i < 8; ++i) {
+      if (coils.size() == number_of_points_) {
+        break;
+      }
+      const bool bit_value = static_cast<bool>((resp_byte & (std::byte{1} << i)) >> i);
+      coils.push_back(bit_value);
+    }
+  }
+  if (coils.size() == 1) {
+    const bool val = coils.at(0);
+    return core::RecordField{val};
+  }
+  core::RecordArray array;
+  for (bool coil : coils) {
+    array.emplace_back(coil);
+  }
+  return core::RecordField{std::move(array)};
+}
+
+[[nodiscard]] std::byte ReadCoilStatus::getFunctionCode() const {
+  return std::byte{0x01};
+}
+
+[[nodiscard]] uint8_t ReadCoilStatus::expectedByteCount() const {
+  return gsl::narrow<uint8_t>((number_of_points_ + 7) / 8);  // one bit per coil, rounded up to whole bytes
+}
+
+bool ReadCoilStatus::operator==(const ReadModbusFunction& rhs) const {
+  const auto read_coil_rhs = dynamic_cast<const ReadCoilStatus*>(&rhs);
+  if (!read_coil_rhs)
+    return false;
+
+  return read_coil_rhs->transaction_id_ == this->transaction_id_ && read_coil_rhs->unit_id_ == this->unit_id_ &&
+      read_coil_rhs->starting_address_ == this->starting_address_ &&
+      read_coil_rhs->number_of_points_ == this->number_of_points_;
+}
+
+std::unique_ptr<ReadModbusFunction> ReadCoilStatus::parse(const uint16_t transaction_id, const uint8_t unit_id, const std::string_view address_str, const std::string_view length_str) {
+  auto start_address = utils::string::parseNumber<uint16_t>(address_str);
+  if (!start_address) {
+    return nullptr;
+  }
+  uint16_t length = length_str.empty() ? 1 : utils::string::parseNumber<uint16_t>(length_str).value_or(1);  // missing or unparsable count -> 1
+
+  return std::make_unique<ReadCoilStatus>(transaction_id, unit_id, *start_address, length);
+}
+
+namespace {
+std::unique_ptr<ReadModbusFunction> parseReadRegister(const RegisterType register_type,
+    const uint16_t transaction_id,
+    const uint8_t unit_id,
+    const std::string_view address_str,
+    const std::string_view type_str,
+    const std::string_view length_str) {
+  auto start_address = utils::string::parseNumber<uint16_t>(address_str);
+  if (!start_address) {
+    return nullptr;
+  }
+  uint16_t length = length_str.empty() ? 1 : utils::string::parseNumber<uint16_t>(length_str).value_or(1);  // missing or unparsable count -> 1
+
+  if (type_str.empty() || type_str == "UINT" || type_str == "WORD") {
+    return std::make_unique<ReadRegisters<uint16_t>>(register_type, transaction_id, unit_id, *start_address, length);
+  }
+
+  if (type_str == "BOOL") {
+    return std::make_unique<ReadRegisters<bool>>(register_type, transaction_id, unit_id, *start_address, length);
+  }
+
+  if (type_str == "SINT") {
+    return std::make_unique<ReadRegisters<int8_t>>(register_type, transaction_id, unit_id, *start_address, length);
+  }
+
+  if (type_str == "USINT" || type_str == "BYTE") {
+    return std::make_unique<ReadRegisters<uint8_t>>(register_type, transaction_id, unit_id, *start_address, length);
+  }
+
+  if (type_str == "INT") {
+    return std::make_unique<ReadRegisters<int16_t>>(register_type, transaction_id, unit_id, *start_address, length);
+  }
+
+  if (type_str == "DINT") {
+    return std::make_unique<ReadRegisters<int32_t>>(register_type, transaction_id, unit_id, *start_address, length);
+  }
+
+  if (type_str == "UDINT" || type_str == "DWORD") {
+    return std::make_unique<ReadRegisters<uint32_t>>(register_type, transaction_id, unit_id, *start_address, length);
+  }
+
+  if (type_str == "LINT") {
+    return std::make_unique<ReadRegisters<int64_t>>(register_type, transaction_id, unit_id, *start_address, length);
+  }
+
+  if (type_str == "ULINT" || type_str == "LWORD") {
+    return std::make_unique<ReadRegisters<uint64_t>>(register_type, transaction_id, unit_id, *start_address, length);
+  }
+
+  if (type_str == "REAL") {
+    return std::make_unique<ReadRegisters<float>>(register_type, transaction_id, unit_id, *start_address, length);
+  }
+
+  if (type_str == "LREAL") {
+    return std::make_unique<ReadRegisters<double>>(register_type, transaction_id, unit_id, *start_address, length);
+  }
+
+  if (type_str == "CHAR") {
+    return std::make_unique<ReadRegisters<char>>(register_type, transaction_id, unit_id, *start_address, length);
+  }
+
+  return nullptr;  // unsupported data type name
+}
+}  // namespace
+// Parses a dynamic-property address (long form e.g. "coil:10[2]", short form e.g. "40020:REAL") into a read request; nullptr if it cannot be parsed.
+std::unique_ptr<ReadModbusFunction> ReadModbusFunction::parse(const uint16_t transaction_id, const uint8_t unit_id, const std::string& address) {
+  static const std::regex address_pattern{R"((holding-register|coil|input-register):(\d+)(:([a-zA-Z_]+))?(\[(\d+)\])?)"};  // groups: 1=register type, 2=start address, 4=data type, 6=count
+
+  std::smatch matches;
+  if (std::regex_match(address, matches, address_pattern)) {
+    if (matches.size() < 7) {  // defensive only: a successful match always has 7 entries (whole match + 6 groups)
+      return nullptr;
+    }
+    const auto register_type_str = matches[1].str();
+    const auto start_address_str = matches[2].str();
+    const auto type_str = matches[4].str();  // empty when no ":TYPE" suffix was given
+    const auto length_str = matches[6].str();  // empty when no "[count]" suffix was given
+
+    if (register_type_str == "coil") {
+      return ReadCoilStatus::parse(transaction_id, unit_id, start_address_str, length_str);
+    }
+    if (register_type_str == "input-register") {
+      return parseReadRegister(RegisterType::input, transaction_id, unit_id, start_address_str, type_str, length_str);
+    }
+    if (register_type_str == "holding-register") {
+      return parseReadRegister(RegisterType::holding, transaction_id, unit_id, start_address_str, type_str, length_str);
+    }
+  }
+
+  static const std::regex address_pattern_short{R"((\dx|\d)(\d{4,5})?(:([a-zA-Z_]+))?(\[(\d+)\])?)"};  // same group layout, numeric register-type prefix
+  if (std::regex_match(address, matches, address_pattern_short)) {
+    if (matches.size() < 7) {  // defensive only, see above
+      return nullptr;
+    }
+    const auto register_type_str = matches[1].str();
+    const auto start_address_str = matches[2].str();
+    const auto type_str = matches[4].str();
+    const auto length_str = matches[6].str();
+
+    if (register_type_str == "1" || register_type_str == "1x") {  // NOTE(review): conventional Modbus notation uses 0x for coils and 1x for discrete inputs; confirm this mapping
+      return ReadCoilStatus::parse(transaction_id, unit_id, start_address_str, length_str);
+    }
+    if (register_type_str == "3" || register_type_str == "3x") {
+      return parseReadRegister(RegisterType::input, transaction_id, unit_id, start_address_str, type_str, length_str);
+    }
+    if (register_type_str == "4" || register_type_str == "4x") {
+      return parseReadRegister(RegisterType::holding, transaction_id, unit_id, start_address_str, type_str, length_str);
+    }
+  }
+
+  return nullptr;
+}
+}  // namespace org::apache::nifi::minifi::modbus
diff --git a/extensions/standard-processors/modbus/ReadModbusFunctions.h b/extensions/standard-processors/modbus/ReadModbusFunctions.h
new file mode 100644
index 000000000..d96e3d09b
--- /dev/null
+++ b/extensions/standard-processors/modbus/ReadModbusFunctions.h
@@ -0,0 +1,172 @@
+/**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <regex>
+#include <vector>
+
+#include "core/RecordField.h"
+#include "fmt/format.h"
+#include "modbus/ByteConverters.h"
+#include "modbus/Error.h"
+#include "range/v3/algorithm/copy.hpp"
+#include "range/v3/view/chunk.hpp"
+#include "range/v3/view/drop.hpp"
+#include "utils/StringUtils.h"
+#include "utils/expected.h"
+
+namespace org::apache::nifi::minifi::modbus {
+enum class RegisterType {
+  holding,
+  input
+};
+
+class ReadModbusFunction {  // one read request (address + data type) and the decoding of its response
+ public:
+  explicit ReadModbusFunction(const uint16_t transaction_id, const uint8_t unit_id) : transaction_id_(transaction_id), unit_id_(unit_id) {
+  }
+  ReadModbusFunction(const ReadModbusFunction&) = delete;
+  ReadModbusFunction(ReadModbusFunction&&) = delete;
+  ReadModbusFunction& operator=(ReadModbusFunction&&) = delete;
+  ReadModbusFunction& operator=(const ReadModbusFunction&) = delete;
+
+  virtual bool operator==(const ReadModbusFunction&) const = 0;
+
+  virtual ~ReadModbusFunction() = default;
+
+  [[nodiscard]] std::vector<std::byte> requestBytes() const;  // MBAP header followed by rawPdu()
+
+  [[nodiscard]] uint16_t getTransactionId() const { return transaction_id_; }
+  [[nodiscard]] uint8_t getUnitId() const { return unit_id_; }
+  [[nodiscard]] virtual nonstd::expected<core::RecordField, std::error_code> responseToRecordField(std::span<const std::byte> resp_pdu) const = 0;
+
+  static std::unique_ptr<ReadModbusFunction> parse(uint16_t transaction_id, uint8_t unit_id, const std::string& address);
+
+ protected:
+  [[nodiscard]] auto getRespBytes(std::span<const std::byte> resp_pdu) const -> nonstd::expected<std::span<const std::byte>, std::error_code>;
+
+  [[nodiscard]] virtual std::byte getFunctionCode() const = 0;
+  [[nodiscard]] virtual std::array<std::byte, 5> rawPdu() const = 0;
+  [[nodiscard]] virtual uint8_t expectedByteCount() const = 0;
+
+  const uint16_t transaction_id_;
+  const uint8_t unit_id_;
+};
+
+class ReadCoilStatus final : public ReadModbusFunction {
+ public:
+  ReadCoilStatus(const uint16_t transaction_id, const uint8_t unit_id, const uint16_t starting_address, const uint16_t number_of_points)
+      : ReadModbusFunction(transaction_id, unit_id),
+        starting_address_(starting_address),
+        number_of_points_(number_of_points) {
+  }
+
+  [[nodiscard]] nonstd::expected<core::RecordField, std::error_code> responseToRecordField(std::span<const std::byte> resp_pdu) const override;
+
+  [[nodiscard]] std::byte getFunctionCode() const override;
+  [[nodiscard]] std::array<std::byte, 5> rawPdu() const override;
+  [[nodiscard]] uint8_t expectedByteCount() const override;
+
+  bool operator==(const ReadCoilStatus& rhs) const = default;
+  bool operator==(const ReadModbusFunction& rhs) const override;
+
+  static std::unique_ptr<ReadModbusFunction> parse(uint16_t transaction_id, uint8_t unit_id, std::string_view address_str, std::string_view length_str);
+
+ private:
+  uint16_t starting_address_{};
+  uint16_t number_of_points_{};
+};
+
+template<typename T>
+class ReadRegisters final : public ReadModbusFunction {
+ public:
+  ReadRegisters(const RegisterType register_type, const uint16_t transaction_id, const uint8_t unit_id, const uint16_t starting_address, const uint16_t number_of_points)
+      : ReadModbusFunction(transaction_id, unit_id),
+        register_type_(register_type),
+        starting_address_(starting_address),
+        number_of_points_(number_of_points) {
+  }
+
+  [[nodiscard]] std::array<std::byte, 5> rawPdu() const override {
+    std::array<std::byte, 5> result{};
+    result[0] = getFunctionCode();
+    ranges::copy(toBytes(starting_address_), (result | ranges::views::drop(1) | ranges::views::take(2)).begin());
+    ranges::copy(toBytes(wordCount()), (result | ranges::views::drop(3) | ranges::views::take(2)).begin());
+
+    return result;
+  }
+
+  [[nodiscard]] uint8_t expectedByteCount() const override {  // every point occupies at least one 16-bit register
+    return gsl::narrow<uint8_t>(number_of_points_*std::max(sizeof(T), sizeof(uint16_t)));
+  }
+
+  [[nodiscard]] uint16_t wordCount() const {
+    return expectedByteCount() / sizeof(uint16_t);
+  }
+
+  [[nodiscard]] nonstd::expected<core::RecordField, std::error_code> responseToRecordField(const std::span<const std::byte> resp_pdu) const override {
+    const auto resp_bytes = getRespBytes(resp_pdu);
+    if (!resp_bytes)
+      return nonstd::make_unexpected(resp_bytes.error());
+
+    std::vector<T> registers{};
+    for (auto&& register_chunk : ranges::views::chunk(*resp_bytes, std::max(sizeof(T), sizeof(uint16_t)))) {
+      std::array<std::byte, std::max(sizeof(T), sizeof(uint16_t))> register_value{};
+      ranges::copy(register_chunk, register_value.begin());
+      registers.push_back(fromBytes<T>(std::move(register_value)));
+    }
+    if (registers.size() == 1) {
+      T val = registers.at(0);
+      return core::RecordField{val};
+    }
+    core::RecordArray record_array;
+    for (auto value : registers) {
+      record_array.emplace_back(value);
+    }
+    return core::RecordField{std::move(record_array)};
+  }
+
+  [[nodiscard]] std::byte getFunctionCode() const override {
+    switch (register_type_) {
+      case RegisterType::holding:
+        return std::byte{0x03};  // Read Holding Registers
+      case RegisterType::input:
+        return std::byte{0x04};  // Read Input Registers
+      default:
+        throw std::invalid_argument(fmt::format("Invalid RegisterType {}", magic_enum::enum_underlying(register_type_)));
+    }
+  }
+
+  bool operator==(const ReadModbusFunction& rhs) const override {
+    const auto registers_rhs = dynamic_cast<const ReadRegisters<T>*>(&rhs);
+    if (!registers_rhs)
+      return false;
+
+    return registers_rhs->register_type_ == this->register_type_ && registers_rhs->number_of_points_ == this->number_of_points_ &&
+        registers_rhs->starting_address_ == this->starting_address_ &&
+        registers_rhs->transaction_id_ == this->transaction_id_ && registers_rhs->unit_id_ == this->unit_id_;
+  }
+
+ private:
+  RegisterType register_type_{};
+  uint16_t starting_address_{};
+  uint16_t number_of_points_{};
+};
+
+}  // namespace org::apache::nifi::minifi::modbus
diff --git a/extensions/standard-processors/processors/PutTCP.cpp
b/extensions/standard-processors/processors/PutTCP.cpp index 5593415fc..75f6a860b 100644 --- a/extensions/standard-processors/processors/PutTCP.cpp +++ b/extensions/standard-processors/processors/PutTCP.cpp @@ -29,15 +29,9 @@ #include "utils/net/AsioCoro.h" #include "utils/net/AsioSocketUtils.h" -using asio::ip::tcp; - using namespace std::literals::chrono_literals; -using std::chrono::steady_clock; -using org::apache::nifi::minifi::utils::net::use_nothrow_awaitable; -using org::apache::nifi::minifi::utils::net::HandshakeType; using org::apache::nifi::minifi::utils::net::TcpSocket; using org::apache::nifi::minifi::utils::net::SslSocket; -using org::apache::nifi::minifi::utils::net::asyncOperationWithTimeout; namespace org::apache::nifi::minifi::processors { @@ -63,12 +57,12 @@ void PutTCP::onSchedule(core::ProcessContext& context, core::ProcessSessionFacto if (context.getProperty(Port).value_or(std::string{}).empty()) { throw Exception{ExceptionType::PROCESSOR_EXCEPTION, "missing port"}; } - if (auto idle_connection_expiration = context.getProperty<core::TimePeriodValue>(IdleConnectionExpiration); idle_connection_expiration && idle_connection_expiration->getMilliseconds() > 0ms) + if (const auto idle_connection_expiration = context.getProperty<core::TimePeriodValue>(IdleConnectionExpiration); idle_connection_expiration && idle_connection_expiration->getMilliseconds() > 0ms) idle_connection_expiration_ = idle_connection_expiration->getMilliseconds(); else idle_connection_expiration_.reset(); - if (auto timeout = context.getProperty<core::TimePeriodValue>(Timeout); timeout && timeout->getMilliseconds() > 0ms) + if (const auto timeout = context.getProperty<core::TimePeriodValue>(Timeout); timeout && timeout->getMilliseconds() > 0ms) timeout_duration_ = timeout->getMilliseconds(); else timeout_duration_ = 15s; @@ -82,7 +76,7 @@ void PutTCP::onSchedule(core::ProcessContext& context, core::ProcessSessionFacto ssl_context_.reset(); if 
(context.getProperty(SSLContextService, context_name) && !IsNullOrEmpty(context_name)) { if (auto controller_service = context.getControllerService(context_name)) { - if (auto ssl_context_service = std::dynamic_pointer_cast<minifi::controllers::SSLContextService>(context.getControllerService(context_name))) { + if (const auto ssl_context_service = std::dynamic_pointer_cast<minifi::controllers::SSLContextService>(context.getControllerService(context_name))) { ssl_context_ = utils::net::getSslContext(*ssl_context_service); } else { throw Exception(PROCESS_SCHEDULE_EXCEPTION, context_name + " is not an SSL Context Service"); @@ -95,180 +89,12 @@ void PutTCP::onSchedule(core::ProcessContext& context, core::ProcessSessionFacto const auto delimiter_str = context.getProperty(OutgoingMessageDelimiter).value_or(std::string{}); delimiter_ = utils::span_to<std::vector>(as_bytes(std::span(delimiter_str))); - if (auto max_size_of_socket_send_buffer = context.getProperty<core::DataSizeValue>(MaxSizeOfSocketSendBuffer)) + if (const auto max_size_of_socket_send_buffer = context.getProperty<core::DataSizeValue>(MaxSizeOfSocketSendBuffer)) max_size_of_socket_send_buffer_ = max_size_of_socket_send_buffer->getValue(); else max_size_of_socket_send_buffer_.reset(); } -namespace { - -template<class SocketType> -class ConnectionHandler : public ConnectionHandlerBase { - public: - ConnectionHandler(utils::net::ConnectionId connection_id, - std::chrono::milliseconds timeout, - std::shared_ptr<core::logging::Logger> logger, - std::optional<size_t> max_size_of_socket_send_buffer, - asio::ssl::context* ssl_context) - : connection_id_(std::move(connection_id)), - timeout_duration_(timeout), - logger_(std::move(logger)), - max_size_of_socket_send_buffer_(max_size_of_socket_send_buffer), - ssl_context_(ssl_context) { - } - - ConnectionHandler(ConnectionHandler&&) = delete; - ConnectionHandler(const ConnectionHandler&) = delete; - ConnectionHandler& operator=(ConnectionHandler&&) = delete; - 
ConnectionHandler& operator=(const ConnectionHandler&) = delete; - - ~ConnectionHandler() override { - shutdownSocket(); - } - - asio::awaitable<std::error_code> sendStreamWithDelimiter(const std::shared_ptr<io::InputStream>& stream_to_send, - const std::vector<std::byte>& delimiter, - asio::io_context& io_context_) override; - - private: - [[nodiscard]] bool hasBeenUsedIn(std::chrono::milliseconds dur) const override { - return last_used_ && *last_used_ >= (steady_clock::now() - dur); - } - - void reset() override { - last_used_.reset(); - socket_.reset(); - } - - [[nodiscard]] bool hasBeenUsed() const override { return last_used_.has_value(); } - [[nodiscard]] asio::awaitable<std::error_code> setupUsableSocket(asio::io_context& io_context); - [[nodiscard]] bool hasUsableSocket() const { return socket_ && socket_->lowest_layer().is_open(); } - - asio::awaitable<std::error_code> establishNewConnection(const tcp::resolver::results_type& endpoints, asio::io_context& io_context_); - asio::awaitable<std::error_code> send(const std::shared_ptr<io::InputStream>& stream_to_send, const std::vector<std::byte>& delimiter); - - SocketType createNewSocket(asio::io_context& io_context_); - void shutdownSocket(); - - utils::net::ConnectionId connection_id_; - std::optional<SocketType> socket_; - - std::optional<steady_clock::time_point> last_used_; - asio::steady_timer::duration timeout_duration_; - - std::shared_ptr<core::logging::Logger> logger_; - std::optional<size_t> max_size_of_socket_send_buffer_; - - asio::ssl::context* ssl_context_; -}; - -template<> -TcpSocket ConnectionHandler<TcpSocket>::createNewSocket(asio::io_context& io_context_) { - gsl_Expects(!ssl_context_); - return TcpSocket{io_context_}; -} - -template<> -SslSocket ConnectionHandler<SslSocket>::createNewSocket(asio::io_context& io_context_) { - gsl_Expects(ssl_context_); - return {io_context_, *ssl_context_}; -} - -template<> -void ConnectionHandler<TcpSocket>::shutdownSocket() { -} - -template<> -void 
ConnectionHandler<SslSocket>::shutdownSocket() { - gsl_Expects(ssl_context_); - if (socket_) { - asio::error_code ec; - socket_->lowest_layer().cancel(ec); - if (ec) { - logger_->log_error("Cancelling asynchronous operations of SSL socket failed with: {}", ec.message()); - } - socket_->shutdown(ec); - if (ec) { - logger_->log_error("Shutdown of SSL socket failed with: {}", ec.message()); - } - } -} - -template<class SocketType> -asio::awaitable<std::error_code> ConnectionHandler<SocketType>::establishNewConnection(const tcp::resolver::results_type& endpoints, asio::io_context& io_context) { - auto socket = createNewSocket(io_context); - std::error_code last_error; - for (const auto& endpoint : endpoints) { - auto [connection_error] = co_await asyncOperationWithTimeout(socket.lowest_layer().async_connect(endpoint, use_nothrow_awaitable), timeout_duration_); - if (connection_error) { - logger_->log_debug("Connecting to {} failed due to {}", endpoint.endpoint(), connection_error.message()); - last_error = connection_error; - continue; - } - auto [handshake_error] = co_await utils::net::handshake(socket, timeout_duration_); - if (handshake_error) { - logger_->log_debug("Handshake with {} failed due to {}", endpoint.endpoint(), handshake_error.message()); - last_error = handshake_error; - continue; - } - if (max_size_of_socket_send_buffer_) - socket.lowest_layer().set_option(TcpSocket::send_buffer_size(gsl::narrow<int>(*max_size_of_socket_send_buffer_))); - socket_.emplace(std::move(socket)); - co_return std::error_code(); - } - co_return last_error; -} - -template<class SocketType> -[[nodiscard]] asio::awaitable<std::error_code> ConnectionHandler<SocketType>::setupUsableSocket(asio::io_context& io_context) { - if (hasUsableSocket()) - co_return std::error_code(); - tcp::resolver resolver(io_context); - auto [resolve_error, resolve_result] = co_await asyncOperationWithTimeout( - resolver.async_resolve(connection_id_.getHostname(), connection_id_.getService(), 
use_nothrow_awaitable), timeout_duration_); - if (resolve_error) - co_return resolve_error; - co_return co_await establishNewConnection(resolve_result, io_context); -} - -template<class SocketType> -asio::awaitable<std::error_code> ConnectionHandler<SocketType>::sendStreamWithDelimiter(const std::shared_ptr<io::InputStream>& stream_to_send, - const std::vector<std::byte>& delimiter, - asio::io_context& io_context) { - if (auto connection_error = co_await setupUsableSocket(io_context)) // NOLINT - co_return connection_error; - co_return co_await send(stream_to_send, delimiter); -} - -template<class SocketType> -asio::awaitable<std::error_code> ConnectionHandler<SocketType>::send(const std::shared_ptr<io::InputStream>& stream_to_send, const std::vector<std::byte>& delimiter) { - gsl_Expects(hasUsableSocket()); - - std::vector<std::byte> data_chunk; - data_chunk.resize(chunk_size); - std::span<std::byte> buffer{data_chunk}; - while (stream_to_send->tell() < stream_to_send->size()) { - size_t num_read = stream_to_send->read(buffer); - if (io::isError(num_read)) - co_return std::make_error_code(std::errc::io_error); - auto [write_error, bytes_written] = co_await asyncOperationWithTimeout( - asio::async_write(*socket_, asio::buffer(data_chunk, num_read), use_nothrow_awaitable), timeout_duration_); - if (write_error) - co_return write_error; - logger_->log_trace("Writing flowfile({} bytes) to socket succeeded", bytes_written); - } - auto [delimiter_write_error, delimiter_bytes_written] = co_await asyncOperationWithTimeout( - asio::async_write(*socket_, asio::buffer(delimiter), use_nothrow_awaitable), timeout_duration_); - if (delimiter_write_error) - co_return delimiter_write_error; - logger_->log_trace("Writing delimiter({} bytes) to socket succeeded", delimiter_bytes_written); - - last_used_ = steady_clock::now(); - co_return std::error_code(); -} -} // namespace - void PutTCP::onTrigger(core::ProcessContext& context, core::ProcessSession& session) { const auto 
flow_file = session.get(); if (!flow_file) { @@ -289,12 +115,12 @@ void PutTCP::onTrigger(core::ProcessContext& context, core::ProcessSession& sess } auto connection_id = utils::net::ConnectionId(std::move(hostname), std::move(port)); - std::shared_ptr<ConnectionHandlerBase> handler; + std::shared_ptr<utils::net::ConnectionHandlerBase> handler; if (!connections_ || !connections_->contains(connection_id)) { if (ssl_context_) - handler = std::make_shared<ConnectionHandler<SslSocket>>(connection_id, timeout_duration_, logger_, max_size_of_socket_send_buffer_, &*ssl_context_); + handler = std::make_shared<utils::net::ConnectionHandler<SslSocket>>(connection_id, timeout_duration_, logger_, max_size_of_socket_send_buffer_, &*ssl_context_); else - handler = std::make_shared<ConnectionHandler<TcpSocket>>(connection_id, timeout_duration_, logger_, max_size_of_socket_send_buffer_, nullptr); + handler = std::make_shared<utils::net::ConnectionHandler<TcpSocket>>(connection_id, timeout_duration_, logger_, max_size_of_socket_send_buffer_, nullptr); if (connections_) (*connections_)[connection_id] = handler; } else { @@ -315,23 +141,49 @@ void PutTCP::removeExpiredConnections() { } } -std::error_code PutTCP::sendFlowFileContent(std::shared_ptr<ConnectionHandlerBase>& connection_handler, +std::error_code PutTCP::sendFlowFileContent(const std::shared_ptr<utils::net::ConnectionHandlerBase>& connection_handler, const std::shared_ptr<io::InputStream>& flow_file_content_stream) { std::error_code operation_error; io_context_.restart(); asio::co_spawn(io_context_, - connection_handler->sendStreamWithDelimiter(flow_file_content_stream, delimiter_, io_context_), - [&operation_error](const std::exception_ptr&, std::error_code error_code) { - operation_error = error_code; - }); + sendStreamWithDelimiter(*connection_handler, flow_file_content_stream, delimiter_), + [&operation_error](const std::exception_ptr&, const std::error_code error_code) { + operation_error = error_code; + }); 
io_context_.run(); return operation_error; } -void PutTCP::processFlowFile(std::shared_ptr<ConnectionHandlerBase>& connection_handler, +asio::awaitable<std::error_code> PutTCP::sendStreamWithDelimiter(utils::net::ConnectionHandlerBase& connection_handler, + const std::shared_ptr<io::InputStream>& stream_to_send, const std::vector<std::byte>& delimiter) { + if (auto connection_error = co_await connection_handler.setupUsableSocket(io_context_)) { // NOLINT (clang tidy doesnt like coroutines) + co_return connection_error; + } + + std::vector<std::byte> data_chunk; + data_chunk.resize(chunk_size); + const std::span<std::byte> buffer{data_chunk}; + while (stream_to_send->tell() < stream_to_send->size()) { + const size_t num_read = stream_to_send->read(buffer); + if (io::isError(num_read)) + co_return std::make_error_code(std::errc::io_error); + auto [write_error, bytes_written] = co_await connection_handler.write(asio::buffer(data_chunk, num_read)); + if (write_error) + co_return write_error; + logger_->log_trace("Writing flowfile({} bytes) to socket succeeded", bytes_written); + } + auto [delimiter_write_error, delimiter_bytes_written] = co_await connection_handler.write(asio::buffer(delimiter)); + if (delimiter_write_error) + co_return delimiter_write_error; + logger_->log_trace("Writing delimiter({} bytes) to socket succeeded", delimiter_bytes_written); + + co_return std::error_code(); +} + +void PutTCP::processFlowFile(const std::shared_ptr<utils::net::ConnectionHandlerBase>& connection_handler, core::ProcessSession& session, const std::shared_ptr<core::FlowFile>& flow_file) { - auto flow_file_content_stream = session.getFlowFileContentStream(*flow_file); + const auto flow_file_content_stream = session.getFlowFileContentStream(*flow_file); if (!flow_file_content_stream) { session.transfer(flow_file, Failure); return; diff --git a/extensions/standard-processors/processors/PutTCP.h b/extensions/standard-processors/processors/PutTCP.h index a3c6e3a70..79446c733 100644 
--- a/extensions/standard-processors/processors/PutTCP.h +++ b/extensions/standard-processors/processors/PutTCP.h @@ -33,26 +33,15 @@ #include "core/PropertyDefinitionBuilder.h" #include "core/PropertyType.h" #include "core/RelationshipDefinition.h" -#include "utils/expected.h" #include "utils/StringUtils.h" // for string <=> on libc++ #include "utils/net/AsioSocketUtils.h" +#include "utils/net/ConnectionHandler.h" #include <asio/io_context.hpp> -#include <asio/awaitable.hpp> #include <asio/ssl/context.hpp> namespace org::apache::nifi::minifi::processors { -class ConnectionHandlerBase { - public: - virtual ~ConnectionHandlerBase() = default; - virtual void reset() = 0; - - [[nodiscard]] virtual bool hasBeenUsed() const = 0; - [[nodiscard]] virtual bool hasBeenUsedIn(std::chrono::milliseconds dur) const = 0; - [[nodiscard]] virtual asio::awaitable<std::error_code> sendStreamWithDelimiter(const std::shared_ptr<io::InputStream>& stream_to_send, - const std::vector<std::byte>& delimiter, - asio::io_context& io_context) = 0; -}; + class PutTCP final : public core::Processor { public: @@ -138,7 +127,9 @@ class PutTCP final : public core::Processor { explicit PutTCP(const std::string& name, const utils::Identifier& uuid = {}); PutTCP(const PutTCP&) = delete; + PutTCP(PutTCP&&) = delete; PutTCP& operator=(const PutTCP&) = delete; + PutTCP& operator=(PutTCP&&) = delete; ~PutTCP() final; void initialize() final; @@ -148,16 +139,20 @@ class PutTCP final : public core::Processor { private: void removeExpiredConnections(); - void processFlowFile(std::shared_ptr<ConnectionHandlerBase>& connection_handler, + void processFlowFile(const std::shared_ptr<utils::net::ConnectionHandlerBase>& connection_handler, core::ProcessSession& session, const std::shared_ptr<core::FlowFile>& flow_file); - std::error_code sendFlowFileContent(std::shared_ptr<ConnectionHandlerBase>& connection_handler, + std::error_code sendFlowFileContent(const std::shared_ptr<utils::net::ConnectionHandlerBase>& 
connection_handler, const std::shared_ptr<io::InputStream>& flow_file_content_stream); + asio::awaitable<std::error_code> sendStreamWithDelimiter(utils::net::ConnectionHandlerBase& connection_handler, + const std::shared_ptr<io::InputStream>& stream_to_send, + const std::vector<std::byte>& delimiter); + std::vector<std::byte> delimiter_; asio::io_context io_context_; - std::optional<std::unordered_map<utils::net::ConnectionId, std::shared_ptr<ConnectionHandlerBase>>> connections_; + std::optional<std::unordered_map<utils::net::ConnectionId, std::shared_ptr<utils::net::ConnectionHandlerBase>>> connections_; std::optional<std::chrono::milliseconds> idle_connection_expiration_; std::optional<size_t> max_size_of_socket_send_buffer_; std::chrono::milliseconds timeout_duration_ = std::chrono::seconds(15); diff --git a/extensions/standard-processors/tests/CMakeLists.txt b/extensions/standard-processors/tests/CMakeLists.txt index cf22f305b..4eff16145 100644 --- a/extensions/standard-processors/tests/CMakeLists.txt +++ b/extensions/standard-processors/tests/CMakeLists.txt @@ -21,7 +21,7 @@ include(Coroutines) include(JoltTests) enable_coroutines() -file(GLOB PROCESSOR_UNIT_TESTS "unit/*.cpp") +file(GLOB PROCESSOR_UNIT_TESTS "unit/*.cpp" "unit/modbus/*.cpp") file(GLOB PROCESSOR_INTEGRATION_TESTS "integration/*.cpp") SET(PROCESSOR_INT_TEST_COUNT 0) @@ -49,7 +49,7 @@ FOREACH(testfile ${PROCESSOR_UNIT_TESTS}) MATH(EXPR PROCESSOR_INT_TEST_COUNT "${PROCESSOR_INT_TEST_COUNT}+1") ENDFOREACH() -message("-- Finished building ${PROCESSOR_INT_TEST_COUNT} processor unit test file(s)...") +message("-- Finished building ${PROCESSOR_UNIT_TESTS} processor unit test file(s)...") SET(INT_TEST_COUNT 0) diff --git a/extensions/standard-processors/tests/unit/JsonRecordTests.cpp b/extensions/standard-processors/tests/unit/JsonRecordTests.cpp index 979b5ae64..406ae5b5d 100644 --- a/extensions/standard-processors/tests/unit/JsonRecordTests.cpp +++ 
b/extensions/standard-processors/tests/unit/JsonRecordTests.cpp @@ -29,17 +29,17 @@ namespace org::apache::nifi::minifi::standard::test { -constexpr std::string_view record_per_line_str = R"({"baz":3.14,"qux":[true,false,true],"is_test":true,"bar":123,"quux":{"Aprikose":"apricot","Birne":"pear","Apfel":"apple"},"foo":"asd","when":"2012-07-01T09:53:00Z"} -{"baz":3.141592653589793,"qux":[false,false,true],"is_test":true,"bar":98402134,"quux":{"Aprikose":"abricot","Birne":"poire","Apfel":"pomme"},"foo":"Lorem ipsum dolor sit amet, consectetur adipiscing elit.","when":"2022-11-01T19:52:11Z"} +constexpr std::string_view record_per_line_str = R"({"baz":3.14,"qux":["a","b","c"],"corge":[true,false],"is_test":true,"bar":123,"quux":{"Aprikose":"apricot","Birne":"pear","Apfel":"apple"},"foo":"asd","when":"2012-07-01T09:53:00Z"} +{"baz":3.141592653589793,"qux":["x","y","z"],"corge":[false,false],"is_test":true,"bar":98402134,"quux":{"Aprikose":"abricot","Birne":"poire","Apfel":"pomme"},"foo":"Lorem ipsum dolor sit amet, consectetur adipiscing elit.","when":"2022-11-01T19:52:11Z"} )"; -constexpr std::string_view array_compressed_str = R"([{"baz":3.14,"qux":[true,false,true],"is_test":true,"bar":123,"quux":{"Aprikose":"apricot","Birne":"pear","Apfel":"apple"},"foo":"asd","when":"2012-07-01T09:53:00Z"},{"baz":3.141592653589793,"qux":[false,false,true],"is_test":true,"bar":98402134,"quux":{"Aprikose":"abricot","Birne":"poire","Apfel":"pomme"},"foo":"Lorem ipsum dolor sit amet, consectetur adipiscing elit.","when":"2022-11-01T19:52:11Z"}])"; +constexpr std::string_view array_compressed_str = R"([{"baz":3.14,"qux":["a","b","c"],"corge":[true,false],"is_test":true,"bar":123,"quux":{"Aprikose":"apricot","Birne":"pear","Apfel":"apple"},"foo":"asd","when":"2012-07-01T09:53:00Z"},{"baz":3.141592653589793,"qux":["x","y","z"],"corge":[false,false],"is_test":true,"bar":98402134,"quux":{"Aprikose":"abricot","Birne":"poire","Apfel":"pomme"},"foo":"Lorem ipsum dolor sit amet, consectetur 
adipiscing elit.","when":"2022-11-01T19:52:11Z"}])"; constexpr std::string_view array_pretty_str = R"([ { "baz": 3.14, "qux": [ - true, - false, - true + "a", + "b", + "c" ], "is_test": true, "bar": 123, @@ -48,15 +48,19 @@ constexpr std::string_view array_pretty_str = R"([ "Birne": "pear", "Apfel": "apple" }, + "corge": [ + true, + false + ], "foo": "asd", "when": "2012-07-01T09:53:00Z" }, { "baz": 3.141592653589793, "qux": [ - false, - false, - true + "x", + "y", + "z" ], "is_test": true, "bar": 98402134, @@ -65,6 +69,10 @@ constexpr std::string_view array_pretty_str = R"([ "Birne": "poire", "Apfel": "pomme" }, + "corge": [ + false, + false + ], "foo": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", "when": "2022-11-01T19:52:11Z" } diff --git a/extensions/standard-processors/tests/unit/modbus/ModbusTests.cpp b/extensions/standard-processors/tests/unit/modbus/ModbusTests.cpp new file mode 100644 index 000000000..295dfbbd5 --- /dev/null +++ b/extensions/standard-processors/tests/unit/modbus/ModbusTests.cpp @@ -0,0 +1,259 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <numbers> + +#include "modbus/ReadModbusFunctions.h" +#include "unit/Catch.h" + + +namespace org::apache::nifi::minifi::modbus::test { + +template <typename... Bytes> +std::vector<std::byte> createByteVector(Bytes... bytes) { + return {static_cast<std::byte>(bytes)...}; +} + +template<size_t Size, typename... Bytes> +std::array<std::byte, Size> createByteArray(Bytes... bytes) { + return {static_cast<std::byte>(bytes)...}; +} + +TEST_CASE("ReadCoilStatus") { + const auto read_coil_status = ReadCoilStatus(280, 0, 19, 19); + { + { + CHECK(read_coil_status.rawPdu() == createByteArray<5>(0x01, 0x00, 0x13, 0x00, 0x13)); + CHECK(read_coil_status.requestBytes() == createByteVector(0x01, 0x18, 0x00, 0x00, 0x00, 0x06, 0x00, 0x01, 0x00, 0x13, 0x00, 0x13)); + } + + auto serialized_response = read_coil_status.responseToRecordField(createByteVector(0x01, 0x03, 0xCD, 0x6B, 0x05)); + REQUIRE(serialized_response.has_value()); + auto record_array = core::RecordArray(); + record_array.emplace_back(true); + record_array.emplace_back(false); + record_array.emplace_back(true); + record_array.emplace_back(true); + record_array.emplace_back(false); + record_array.emplace_back(false); + record_array.emplace_back(true); + record_array.emplace_back(true); + record_array.emplace_back(true); + record_array.emplace_back(true); + record_array.emplace_back(false); + record_array.emplace_back(true); + record_array.emplace_back(false); + record_array.emplace_back(true); + record_array.emplace_back(true); + record_array.emplace_back(false); + record_array.emplace_back(true); + record_array.emplace_back(false); + record_array.emplace_back(true); + + CHECK(std::get<core::RecordArray>(serialized_response->value_) == record_array); + } + + { + auto shorter_than_expected_resp = read_coil_status.responseToRecordField(createByteVector(0x01, 0x02, 0xCD, 0x6B)); + REQUIRE(!shorter_than_expected_resp); + CHECK_THAT(shorter_than_expected_resp.error(), 
minifi::test::MatchesError(modbus::ModbusExceptionCode::InvalidResponse)); + } + + + { + auto longer_than_expected_resp = read_coil_status.responseToRecordField(createByteVector(0x01, 0x04, 0xCD, 0x6B, 0x05, 0x07)); + REQUIRE(!longer_than_expected_resp); + CHECK_THAT(longer_than_expected_resp.error(), minifi::test::MatchesError(modbus::ModbusExceptionCode::InvalidResponse)); + } + + { + auto mismatching_size_resp = read_coil_status.responseToRecordField(createByteVector(0x01, 0x03, 0xCD, 0x6B, 0x05, 0x07)); + REQUIRE(!mismatching_size_resp); + CHECK_THAT(mismatching_size_resp.error(), minifi::test::MatchesError(modbus::ModbusExceptionCode::UnexpectedResponsePDUSize)); + } +} + +TEST_CASE("ReadHoldingRegisters uint16_t") { + { + const auto read_holding_registers = ReadRegisters<uint16_t>(RegisterType::holding, 0, 0, 5, 3); + { + CHECK(read_holding_registers.rawPdu() == createByteArray<5>(0x03, 0x00, 0x05, 0x00, 0x03)); + } + + auto serialized_response = read_holding_registers.responseToRecordField(createByteVector(0x03, 0x06, 0x3A, 0x98, 0x13, 0x88, 0x00, 0xC8)); + REQUIRE(serialized_response.has_value()); + auto record_array = core::RecordArray(); + record_array.emplace_back(15000); + record_array.emplace_back(5000); + record_array.emplace_back(200); + CHECK(std::get<core::RecordArray>(serialized_response->value_) == record_array); + } +} + +TEST_CASE("ReadHoldingRegisters char") { + { + const auto read_holding_registers = ReadRegisters<char>(RegisterType::holding, 0, 0, 5, 3); + { + CHECK(read_holding_registers.rawPdu() == createByteArray<5>(0x03, 0x00, 0x05, 0x00, 0x03)); + } + + auto serialized_response = read_holding_registers.responseToRecordField(createByteVector(0x03, 0x06, 0x00, 0x66, 0x00, 0x6F, 0x00, 0x6F)); + REQUIRE(serialized_response.has_value()); + auto record_array = core::RecordArray(); + record_array.emplace_back('f'); + record_array.emplace_back('o'); + record_array.emplace_back('o'); + 
CHECK(std::get<core::RecordArray>(serialized_response->value_) == record_array); + } +} + +TEST_CASE("ReadInputRegisters") { + { + const auto read_input_registers = ReadRegisters<uint16_t>(RegisterType::input, 0, 0, 5, 3); + { + CHECK(read_input_registers.rawPdu() == createByteArray<5>(0x04, 0x00, 0x05, 0x00, 0x03)); + } + auto serialized_response = read_input_registers.responseToRecordField(createByteVector(0x04, 0x06, 0x3A, 0x98, 0x13, 0x88, 0x00, 0xC8)); + REQUIRE(serialized_response.has_value()); + auto record_array = core::RecordArray(); + record_array.emplace_back(15000); + record_array.emplace_back(5000); + record_array.emplace_back(200); + CHECK(std::get<core::RecordArray>(serialized_response->value_) == record_array); + } +} + +TEST_CASE("ByteConversion") { + { + constexpr std::array from{std::byte{0x12}, std::byte{0x24}}; + + CHECK(4644 == modbus::fromBytes<uint16_t>(from)); + CHECK(9234 == modbus::fromBytes<uint16_t, std::endian::little>(from)); + } + + { + constexpr std::array from{std::byte{0x00}, std::byte{0x61}}; + CHECK('a' == modbus::fromBytes<char>(from)); + } + + { + constexpr std::array from{std::byte{0x61}, std::byte{0x00}}; + CHECK('\0' == modbus::fromBytes<char>(from)); + } + + { + constexpr std::array from{std::byte{0x1A}, std::byte{0x45}, std::byte{0x02}, std::byte{0x3F}}; + CHECK(440730175 == modbus::fromBytes<uint32_t>(from)); + } + + { + constexpr std::array from{std::byte{0x40}, std::byte{0xD8}, std::byte{0xF5}, std::byte{0xC3}}; + CHECK(6.78F == modbus::fromBytes<float>(from)); + } + + { + constexpr std::array<std::byte, 8> pi_double{ + std::byte{0x40}, std::byte{0x09}, std::byte{0x21}, std::byte{0xFB}, + std::byte{0x54}, std::byte{0x44}, std::byte{0x2d}, std::byte{0x18}}; + CHECK(std::numbers::pi == modbus::fromBytes<double>(pi_double)); + } +} + +TEST_CASE("ParseAddress") { + constexpr uint16_t transaction_id = 1; + constexpr uint8_t unit_id = 0; + { + auto expected = ReadRegisters<uint16_t>(RegisterType::holding, transaction_id, 
unit_id, 20, 10); + + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "holding-register:20:UINT[10]") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "400020:UINT[10]") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "4x00020:UINT[10]") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "40020:UINT[10]") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "4x0020:UINT[10]") == expected); + } + + { + auto expected = ReadRegisters<uint16_t>(RegisterType::holding, transaction_id, unit_id, 5678, 1); + + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "holding-register:5678") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "405678") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "4x05678") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "45678") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "4x5678") == expected); + } + + { + auto expected = ReadRegisters<uint16_t>(RegisterType::input, transaction_id, unit_id, 5678, 1); + + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "input-register:5678") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "305678") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "3x05678") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "35678") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "3x5678") == expected); + } + + { + auto expected = ReadRegisters<char>(RegisterType::holding, transaction_id, unit_id, 5678, 1); + + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "holding-register:5678:CHAR") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "405678:CHAR") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "4x05678:CHAR") == expected); + 
CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "45678:CHAR") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "4x5678:CHAR") == expected); + } + + { + auto expected = ReadRegisters<float>(RegisterType::holding, transaction_id, unit_id, 7777, 2); + + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "holding-register:7777:REAL[2]") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "407777:REAL[2]") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "4x07777:REAL[2]") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "47777:REAL[2]") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "4x7777:REAL[2]") == expected); + } + + { + auto expected = ReadRegisters<uint16_t>(RegisterType::input, transaction_id, unit_id, 5678, 1); + + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "input-register:5678") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "305678") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "3x05678") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "35678") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "3x5678") == expected); + } + + { + auto expected = ReadCoilStatus(transaction_id, unit_id, 4234, 1); + + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "coil:4234") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "104234") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "1x04234") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "14234") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "1x4234") == expected); + } + + { + auto expected = ReadCoilStatus(transaction_id, unit_id, 222, 12); + + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "coil:222[12]") == expected); + 
CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "100222[12]") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "1x00222[12]") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "10222[12]") == expected); + CHECK(*ReadModbusFunction::parse(transaction_id, unit_id, "1x0222[12]") == expected); + } +} + +} // namespace org::apache::nifi::minifi::modbus::test diff --git a/libminifi/include/core/Record.h b/libminifi/include/core/Record.h index 632761212..7e2a5a89d 100644 --- a/libminifi/include/core/Record.h +++ b/libminifi/include/core/Record.h @@ -26,10 +26,17 @@ namespace org::apache::nifi::minifi::core { -class Record { +class Record final { public: Record() = default; - Record(Record&& rhs) noexcept : fields_(std::move(rhs.fields_)) {} + Record(Record&& rhs) noexcept = default; + Record& operator=(Record&& rhs) noexcept = default; + + Record(const Record&) = delete; + Record& operator=(const Record&) = delete; + + ~Record() = default; + auto emplace(std::string key, RecordField&& field) { return fields_.emplace(std::move(key), std::move(field)); } diff --git a/libminifi/include/core/RecordField.h b/libminifi/include/core/RecordField.h index f3fbf0058..ae89e67c2 100644 --- a/libminifi/include/core/RecordField.h +++ b/libminifi/include/core/RecordField.h @@ -23,6 +23,7 @@ #include <memory> #include <variant> #include <chrono> +#include <concepts> namespace org::apache::nifi::minifi::core { @@ -49,12 +50,27 @@ using RecordArray = std::vector<RecordField>; using RecordObject = std::unordered_map<std::string, BoxedRecordField>; +template<typename T> +concept Float = std::is_floating_point_v<T>; + +template<typename T> +concept Integer = std::integral<T>; + struct RecordField { - explicit RecordField(std::variant<std::string, int64_t, double, bool, std::chrono::system_clock::time_point, RecordArray, RecordObject> value) : value_(std::move(value)) {} - RecordField(const RecordField& field) = delete; + explicit 
RecordField(RecordObject ro) : value_(std::move(ro)) {} + explicit RecordField(RecordArray ra) : value_(std::move(ra)) {} + explicit RecordField(std::string s) : value_(std::move(s)) {} + explicit RecordField(std::chrono::system_clock::time_point tp) : value_(tp) {} + explicit RecordField(bool b) : value_(b) {} + explicit RecordField(const char c) : value_(std::string{c}) {} + explicit RecordField(uint64_t u64) : value_(u64) {} + explicit RecordField(Integer auto i64) : value_(int64_t{i64}) {} + explicit RecordField(Float auto f) : value_(f) {} + + RecordField(const RecordField& field) = default; RecordField(RecordField&& field) noexcept : value_(std::move(field.value_)) {} - RecordField& operator=(const RecordField&) = delete; + RecordField& operator=(const RecordField&) = default; RecordField& operator=(RecordField&& field) noexcept { value_ = std::move(field.value_); return *this; @@ -65,7 +81,7 @@ struct RecordField { bool operator==(const RecordField& rhs) const = default; - std::variant<std::string, int64_t, double, bool, std::chrono::system_clock::time_point, RecordArray, RecordObject> value_; + std::variant<std::string, int64_t, uint64_t, double, bool, std::chrono::system_clock::time_point, RecordArray, RecordObject> value_; }; inline bool BoxedRecordField::operator==(const BoxedRecordField& rhs) const { diff --git a/libminifi/include/utils/StringUtils.h b/libminifi/include/utils/StringUtils.h index 2b78153e5..8d83ace8e 100644 --- a/libminifi/include/utils/StringUtils.h +++ b/libminifi/include/utils/StringUtils.h @@ -18,6 +18,7 @@ #include <algorithm> #include <cstring> +#include <charconv> #include <functional> #include <iostream> #include <map> @@ -432,6 +433,22 @@ struct ParseError {}; nonstd::expected<std::optional<char>, ParseError> parseCharacter(std::string_view input); std::string replaceEscapedCharacters(std::string_view input); + +// no std::arithmetic yet +template <typename T> concept arithmetic = std::integral<T> || std::floating_point<T>; + 
+template<arithmetic T> +nonstd::expected<T, ParseError> parseNumber(std::string_view input) { + T t{}; + const auto [ptr, ec] = std::from_chars(input.data(), input.data() + input.size(), t); + if (ec != std::errc()) { + return nonstd::make_unexpected(ParseError{}); + } + if (ptr != input.data() + input.size()) { + return nonstd::make_unexpected(ParseError{}); + } + return t; +} } // namespace string } // namespace org::apache::nifi::minifi::utils diff --git a/libminifi/include/utils/net/AsioSocketUtils.h b/libminifi/include/utils/net/AsioSocketUtils.h index 39b05f322..ceaf9fcb4 100644 --- a/libminifi/include/utils/net/AsioSocketUtils.h +++ b/libminifi/include/utils/net/AsioSocketUtils.h @@ -51,6 +51,8 @@ class ConnectionId { ConnectionId(std::string hostname, std::string port) : hostname_(std::move(hostname)), service_(std::move(port)) {} ConnectionId(const ConnectionId& connection_id) = default; ConnectionId(ConnectionId&& connection_id) = default; + ConnectionId& operator=(ConnectionId&&) = default; + ConnectionId& operator=(const ConnectionId&) = default; auto operator<=>(const ConnectionId&) const = default; @@ -166,16 +168,14 @@ class AsioSocketConnection : public io::BaseStream { } // namespace org::apache::nifi::minifi::utils::net -namespace std { template<> -struct hash<org::apache::nifi::minifi::utils::net::ConnectionId> { - size_t operator()(const org::apache::nifi::minifi::utils::net::ConnectionId& connection_id) const { +struct std::hash<org::apache::nifi::minifi::utils::net::ConnectionId> { + size_t operator()(const org::apache::nifi::minifi::utils::net::ConnectionId& connection_id) const noexcept { return org::apache::nifi::minifi::utils::hash_combine( std::hash<std::string_view>{}(connection_id.getHostname()), std::hash<std::string_view>{}(connection_id.getService())); } }; -} // namespace std template <typename InternetProtocol> struct fmt::formatter<asio::ip::basic_endpoint<InternetProtocol>> : fmt::ostream_formatter {}; diff --git 
a/libminifi/include/utils/net/ConnectionHandler.h b/libminifi/include/utils/net/ConnectionHandler.h new file mode 100644 index 000000000..d7e982560 --- /dev/null +++ b/libminifi/include/utils/net/ConnectionHandler.h @@ -0,0 +1,173 @@ +/** +* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include <asio/read.hpp> + +#include "AsioCoro.h" +#include "AsioSocketUtils.h" +#include "ConnectionHandlerBase.h" + +namespace org::apache::nifi::minifi::utils::net { + +template<class SocketType> +class ConnectionHandler final : public ConnectionHandlerBase { + public: + ConnectionHandler(ConnectionId connection_id, + const std::chrono::milliseconds timeout, + std::shared_ptr<core::logging::Logger> logger, + const std::optional<size_t> max_size_of_socket_send_buffer, + asio::ssl::context* ssl_context) + : connection_id_(std::move(connection_id)), + timeout_duration_(timeout), + logger_(std::move(logger)), + max_size_of_socket_send_buffer_(max_size_of_socket_send_buffer), + ssl_context_(ssl_context) { + } + + ConnectionHandler(ConnectionHandler&&) = delete; + ConnectionHandler(const ConnectionHandler&) = delete; + ConnectionHandler& operator=(ConnectionHandler&&) = delete; + ConnectionHandler& operator=(const ConnectionHandler&) = delete; + + ~ConnectionHandler() override { + shutdownSocket(); + } + + + private: + [[nodiscard]] bool hasBeenUsedIn(std::chrono::milliseconds dur) const override { + return last_used_ && *last_used_ >= (std::chrono::steady_clock::now() - dur); + } + + void reset() override { + last_used_.reset(); + socket_.reset(); + } + + [[nodiscard]] bool hasBeenUsed() const override { return last_used_.has_value(); } + [[nodiscard]] asio::awaitable<std::error_code> setupUsableSocket(asio::io_context& io_context) override; + [[nodiscard]] bool hasUsableSocket() const { return socket_ && socket_->lowest_layer().is_open(); } + + asio::awaitable<std::error_code> establishNewConnection(const asio::ip::tcp::resolver::results_type& endpoints, asio::io_context& io_context_); + [[nodiscard]] asio::awaitable<std::tuple<std::error_code, size_t>> write(const asio::const_buffer& buffer) override; + [[nodiscard]] asio::awaitable<std::tuple<std::error_code, size_t>> read(asio::mutable_buffer& buffer) override; + + SocketType 
createNewSocket(asio::io_context& io_context_); + void shutdownSocket(); + + ConnectionId connection_id_; + std::optional<SocketType> socket_{}; + + std::optional<std::chrono::steady_clock::time_point> last_used_{}; + asio::steady_timer::duration timeout_duration_{}; + + std::shared_ptr<core::logging::Logger> logger_{}; + std::optional<size_t> max_size_of_socket_send_buffer_{}; + + asio::ssl::context* ssl_context_{}; +}; + +template<> +inline TcpSocket ConnectionHandler<TcpSocket>::createNewSocket(asio::io_context& io_context_) { + gsl_Expects(!ssl_context_); + return TcpSocket{io_context_}; +} + +template<> +inline SslSocket ConnectionHandler<SslSocket>::createNewSocket(asio::io_context& io_context_) { + gsl_Expects(ssl_context_); + return {io_context_, *ssl_context_}; +} + +template<> +inline void ConnectionHandler<TcpSocket>::shutdownSocket() { +} + +template<> +inline void ConnectionHandler<SslSocket>::shutdownSocket() { + gsl_Expects(ssl_context_); + if (socket_) { + asio::error_code ec; + socket_->lowest_layer().cancel(ec); + if (ec) { + logger_->log_error("Cancelling asynchronous operations of SSL socket failed with: {}", ec.message()); + } + socket_->shutdown(ec); + if (ec) { + logger_->log_error("Shutdown of SSL socket failed with: {}", ec.message()); + } + } +} + +template<class SocketType> +asio::awaitable<std::error_code> ConnectionHandler<SocketType>::establishNewConnection(const asio::ip::tcp::resolver::results_type& endpoints, asio::io_context& io_context) { + auto socket = createNewSocket(io_context); + std::error_code last_error; + for (const auto& endpoint : endpoints) { + auto [connection_error] = co_await asyncOperationWithTimeout(socket.lowest_layer().async_connect(endpoint, use_nothrow_awaitable), timeout_duration_); + if (connection_error) { + logger_->log_debug("Connecting to {} failed due to {}", endpoint.endpoint(), connection_error.message()); + last_error = connection_error; + continue; + } + auto [handshake_error] = co_await 
handshake(socket, timeout_duration_); + if (handshake_error) { + logger_->log_debug("Handshake with {} failed due to {}", endpoint.endpoint(), handshake_error.message()); + last_error = handshake_error; + continue; + } + if (max_size_of_socket_send_buffer_) + socket.lowest_layer().set_option(TcpSocket::send_buffer_size(gsl::narrow<int>(*max_size_of_socket_send_buffer_))); + socket_.emplace(std::move(socket)); + co_return std::error_code(); + } + co_return last_error; +} + +template<class SocketType> +[[nodiscard]] asio::awaitable<std::error_code> ConnectionHandler<SocketType>::setupUsableSocket(asio::io_context& io_context) { + if (hasUsableSocket()) + co_return std::error_code(); + asio::ip::tcp::resolver resolver(io_context); + auto [resolve_error, resolve_result] = co_await asyncOperationWithTimeout( + resolver.async_resolve(connection_id_.getHostname(), connection_id_.getService(), use_nothrow_awaitable), timeout_duration_); + if (resolve_error) + co_return resolve_error; + co_return co_await establishNewConnection(resolve_result, io_context); +} + +template<class SocketType> +asio::awaitable<std::tuple<std::error_code, size_t>> ConnectionHandler<SocketType>::write(const asio::const_buffer& buffer) { + auto result = co_await asyncOperationWithTimeout(asio::async_write(*socket_, buffer, use_nothrow_awaitable), timeout_duration_); + if (!std::get<std::error_code>(result)) { + last_used_ = std::chrono::steady_clock::now(); + } + co_return result; +} + +template<class SocketType> +asio::awaitable<std::tuple<std::error_code, size_t>> ConnectionHandler<SocketType>::read(asio::mutable_buffer& buffer) { + auto result = co_await asyncOperationWithTimeout(asio::async_read(*socket_, buffer, use_nothrow_awaitable), timeout_duration_); + if (!std::get<std::error_code>(result)) { + last_used_ = std::chrono::steady_clock::now(); + } + co_return result; +} + +} // namespace org::apache::nifi::minifi::utils::net diff --git a/libminifi/include/utils/net/ConnectionHandlerBase.h 
b/libminifi/include/utils/net/ConnectionHandlerBase.h new file mode 100644 index 000000000..847880bb8 --- /dev/null +++ b/libminifi/include/utils/net/ConnectionHandlerBase.h @@ -0,0 +1,41 @@ +/** +* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace org::apache::nifi::minifi::utils::net { + +class ConnectionHandlerBase { + public: + ConnectionHandlerBase() = default; + ConnectionHandlerBase(const ConnectionHandlerBase& connection_id) = delete; + ConnectionHandlerBase(ConnectionHandlerBase&& connection_id) = delete; + ConnectionHandlerBase& operator=(ConnectionHandlerBase&&) = delete; + ConnectionHandlerBase& operator=(const ConnectionHandlerBase&) = delete; + + virtual ~ConnectionHandlerBase() = default; + + virtual void reset() = 0; + + [[nodiscard]] virtual asio::awaitable<std::error_code> setupUsableSocket(asio::io_context& io_context) = 0; + [[nodiscard]] virtual bool hasBeenUsed() const = 0; + [[nodiscard]] virtual bool hasBeenUsedIn(std::chrono::milliseconds dur) const = 0; + [[nodiscard]] virtual asio::awaitable<std::tuple<std::error_code, size_t>> write(const asio::const_buffer& buffer) = 0; + [[nodiscard]] virtual asio::awaitable<std::tuple<std::error_code, size_t>> 
read(asio::mutable_buffer& buffer) = 0; +}; + +} // namespace org::apache::nifi::minifi::utils::net diff --git a/libminifi/test/libtest/unit/Catch.h b/libminifi/test/libtest/unit/Catch.h index f2cc9d347..6ce6fa833 100644 --- a/libminifi/test/libtest/unit/Catch.h +++ b/libminifi/test/libtest/unit/Catch.h @@ -78,7 +78,7 @@ struct MatchesError : Catch::Matchers::MatcherBase<std::error_code> { bool match(const std::error_code& err) const override { if (expected_error_) - return err == *expected_error_; + return err.value() == expected_error_->value(); return err.value() != 0; } diff --git a/libminifi/test/libtest/unit/TestRecord.h b/libminifi/test/libtest/unit/TestRecord.h index 7c332ad87..aaf98e665 100644 --- a/libminifi/test/libtest/unit/TestRecord.h +++ b/libminifi/test/libtest/unit/TestRecord.h @@ -23,61 +23,72 @@ namespace org::apache::nifi::minifi::core::test { -inline Record createSampleRecord2(const bool stringify_date = false) { +inline Record createSampleRecord2(const bool stringify = false) { using namespace date::literals; // NOLINT(google-build-using-namespace) using namespace std::literals::chrono_literals; Record record; - auto when = date::sys_days(2022_y / 11 / 01) + 19h + 52min + 11s; - if (!stringify_date) { + constexpr auto when = date::sys_days(2022_y / 11 / 01) + 19h + 52min + 11s; + if (!stringify) { record.emplace("when", RecordField{when}); } else { record.emplace("when", RecordField{utils::timeutils::getDateTimeStr(std::chrono::floor<std::chrono::seconds>(when))}); } - record.emplace("foo", RecordField{"Lorem ipsum dolor sit amet, consectetur adipiscing elit."}); + record.emplace("foo", RecordField{std::string{"Lorem ipsum dolor sit amet, consectetur adipiscing elit."}}); record.emplace("bar", RecordField{int64_t{98402134}}); record.emplace("baz", RecordField{std::numbers::pi}); record.emplace("is_test", RecordField{true}); RecordArray qux; - qux.emplace_back(false); - qux.emplace_back(false); - qux.emplace_back(true); + qux.emplace_back('x'); 
+ qux.emplace_back('y'); + qux.emplace_back('z'); + RecordObject quux; - quux["Apfel"] = BoxedRecordField{std::make_unique<RecordField>(RecordField{"pomme"})}; - quux["Birne"] = BoxedRecordField{std::make_unique<RecordField>(RecordField{"poire"})}; - quux["Aprikose"] = BoxedRecordField{std::make_unique<RecordField>(RecordField{"abricot"})}; + quux["Apfel"] = BoxedRecordField{std::make_unique<RecordField>(std::string{"pomme"})}; + quux["Birne"] = BoxedRecordField{std::make_unique<RecordField>(std::string{"poire"})}; + quux["Aprikose"] = BoxedRecordField{std::make_unique<RecordField>(std::string{"abricot"})}; + + RecordArray corge; + corge.emplace_back(false); + corge.emplace_back(false); record.emplace("qux", RecordField{std::move(qux)}); record.emplace("quux", RecordField{std::move(quux)}); + record.emplace("corge", RecordField{std::move(corge)}); return record; } -inline Record createSampleRecord(const bool stringify_date = false) { +inline Record createSampleRecord(const bool stringify = false) { using namespace date::literals; // NOLINT(google-build-using-namespace) using namespace std::literals::chrono_literals; Record record; - auto when = date::sys_days(2012_y / 07 / 01) + 9h + 53min + 00s; - if (!stringify_date) { + constexpr auto when = date::sys_days(2012_y / 07 / 01) + 9h + 53min + 00s; + if (!stringify) { record.emplace("when", RecordField{when}); } else { record.emplace("when", RecordField{utils::timeutils::getDateTimeStr(std::chrono::floor<std::chrono::seconds>(when))}); } - record.emplace("foo", RecordField{"asd"}); + record.emplace("foo", RecordField{std::string{"asd"}}); record.emplace("bar", RecordField{int64_t{123}}); record.emplace("baz", RecordField{3.14}); record.emplace("is_test", RecordField{true}); RecordArray qux; - qux.emplace_back(true); - qux.emplace_back(false); - qux.emplace_back(true); + qux.emplace_back('a'); + qux.emplace_back('b'); + qux.emplace_back('c'); RecordObject quux; - quux["Apfel"] = 
BoxedRecordField{std::make_unique<RecordField>(RecordField{"apple"})}; - quux["Birne"] = BoxedRecordField{std::make_unique<RecordField>(RecordField{"pear"})}; - quux["Aprikose"] = BoxedRecordField{std::make_unique<RecordField>(RecordField{"apricot"})}; + quux["Apfel"] = BoxedRecordField{std::make_unique<RecordField>(std::string{"apple"})}; + quux["Birne"] = BoxedRecordField{std::make_unique<RecordField>(std::string{"pear"})}; + quux["Aprikose"] = BoxedRecordField{std::make_unique<RecordField>(std::string{"apricot"})}; + + RecordArray corge; + corge.emplace_back(true); + corge.emplace_back(false); record.emplace("qux", RecordField{std::move(qux)}); record.emplace("quux", RecordField{std::move(quux)}); + record.emplace("corge", RecordField{std::move(corge)}); return record; }
