/*

Copyright 2021 Bryan Rosander

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

*/

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

/**
 * Removes {@code JndiLookup.class} from every jar embedded in the NiFi archives ({@code .nar} files)
 * found under a directory.
 *
 * <p>Only NARs containing an entry whose name includes both "log4j" and "core" (case-insensitive)
 * are rewritten, and a NAR is replaced on disk only if at least one class was actually removed.
 * Instances share a scratch buffer and are therefore not thread-safe.
 */
public class Log4jPatch {
    /** Scratch buffer reused by {@link #copy(InputStream, OutputStream)}. */
    private final byte[] buf = new byte[1_000_000];

    /**
     * Copies all remaining bytes of {@code is} into {@code os}. Neither stream is closed.
     *
     * @param is source; for a {@link ZipInputStream} this is the current entry's data
     * @param os destination
     * @throws IOException if reading or writing fails
     */
    void copy(InputStream is, OutputStream os) throws IOException {
        int length;
        while ((length = is.read(buf)) != -1) {
            os.write(buf, 0, length);
        }
    }

    /**
     * Copies every entry of a jar to {@code jarOutput}, omitting log4j {@code JndiLookup.class}
     * entries. Only entry names are carried over; other entry metadata (timestamps, compression
     * method, comments) is not preserved.
     *
     * @param jarInput  jar being read, positioned before its first entry
     * @param jarOutput destination jar; left open so the caller can finish and close it
     * @param prefix    label identifying the jar in log output
     * @return {@code true} if at least one entry was omitted
     * @throws IOException if reading or writing fails
     */
    boolean patchJar(ZipInputStream jarInput, ZipOutputStream jarOutput, String prefix) throws IOException {
        boolean patched = false;
        for (ZipEntry jarEntry = jarInput.getNextEntry(); jarEntry != null; jarEntry = jarInput.getNextEntry()) {
            String name = jarEntry.getName();
            if (name.contains("log4j") && name.contains("JndiLookup.class")) {
                System.out.println("Omitting " + prefix + ":" + name + " from patched output.");
                patched = true;
            } else {
                jarOutput.putNextEntry(new ZipEntry(name));
                copy(jarInput, jarOutput);
                jarOutput.closeEntry();
            }
        }
        return patched;
    }

    /**
     * Cheap pre-check: reports whether the NAR appears to bundle log4j-core, i.e. has an entry whose
     * lower-cased name contains both "log4j" and "core".
     *
     * @param nar path to a NAR file
     * @return {@code true} if the NAR is a candidate for patching
     * @throws IOException if the NAR cannot be opened as a zip archive
     */
    boolean shouldPatch(Path nar) throws IOException {
        try (ZipFile zip = new ZipFile(nar.toFile())) {
            // The stream is consumed by anyMatch before the ZipFile is closed.
            return zip.stream()
                    .map(entry -> entry.getName().toLowerCase(Locale.ROOT))
                    .anyMatch(name -> name.contains("log4j") && name.contains("core"));
        }
    }

    /**
     * Walks {@code path} recursively and patches every regular file ending in {@code .nar} that
     * {@link #shouldPatch(Path)} accepts.
     *
     * @param path root directory (or a single NAR file) to process
     * @throws IOException if walking the tree or patching any NAR fails
     */
    void patchNars(String path) throws IOException {
        List<Path> nars;
        // Files.walk holds open directory handles until the stream is closed.
        try (Stream<Path> files = Files.walk(Paths.get(path))) {
            nars = files.filter(Log4jPatch::isNarFile).collect(Collectors.toList());
        }
        for (Path nar : nars) {
            System.out.println("Processing " + nar);
            if (shouldPatch(nar)) {
                patchNar(nar);
            }
        }
    }

    /** Returns whether {@code file} is a regular file whose name ends in {@code .nar}. */
    private static boolean isNarFile(Path file) {
        Path fileName = file.getFileName(); // null for a filesystem root
        return fileName != null && fileName.toString().endsWith(".nar") && Files.isRegularFile(file);
    }

    /**
     * Rewrites one NAR into a temporary file and, if anything was removed, replaces the original
     * with it. The temporary file is always cleaned up.
     */
    private void patchNar(Path nar) throws IOException {
        Path patchedNar = Files.createTempFile(nar.getFileName().toString(), ".tmp");
        try {
            boolean patched;
            try (ZipInputStream narInput = new ZipInputStream(Files.newInputStream(nar));
                 ZipOutputStream narOutput = new ZipOutputStream(Files.newOutputStream(patchedNar))) {
                patched = rewriteNar(nar, narInput, narOutput);
            }
            if (patched) {
                System.out.println("Patched " + nar);
                Files.move(patchedNar, nar, StandardCopyOption.REPLACE_EXISTING);
            }
        } finally {
            // No-op after a successful move; removes the leftover copy otherwise.
            Files.deleteIfExists(patchedNar);
        }
    }

    /**
     * Copies every NAR entry to {@code narOutput}, passing {@code .jar} entries through
     * {@link #patchJar}. Returns {@code true} if any nested jar was modified.
     */
    private boolean rewriteNar(Path nar, ZipInputStream narInput, ZipOutputStream narOutput) throws IOException {
        boolean patched = false;
        for (ZipEntry narEntry = narInput.getNextEntry(); narEntry != null; narEntry = narInput.getNextEntry()) {
            String name = narEntry.getName();
            narOutput.putNextEntry(new ZipEntry(name));
            if (name.toLowerCase(Locale.ROOT).endsWith(".jar")) {
                patched |= rewriteNestedJar(narInput, narOutput, nar + ":" + name);
            } else {
                copy(narInput, narOutput);
            }
            narOutput.closeEntry();
        }
        return patched;
    }

    /**
     * Reads the jar stored in the NAR's current entry, patches it in memory, and writes the result
     * to {@code narOutput}. Returns {@code true} if any entry was omitted.
     */
    private boolean rewriteNestedJar(ZipInputStream narInput, OutputStream narOutput, String label)
            throws IOException {
        // Deliberately not closed: closing a ZipInputStream closes the stream it wraps (narInput).
        ZipInputStream jarInput = new ZipInputStream(narInput);
        ByteArrayOutputStream jarBytes = new ByteArrayOutputStream();
        boolean patched;
        try (ZipOutputStream jarOutput = new ZipOutputStream(jarBytes)) {
            patched = patchJar(jarInput, jarOutput, label);
        }
        // jarOutput is closed (central directory written) before its bytes are emitted.
        jarBytes.writeTo(narOutput);
        return patched;
    }

    /**
     * Entry point: {@code java Log4jPatch <nifi-path>}. Exits with status 1 on missing arguments or
     * on any failure.
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            System.err.println("Expected a single argument of nifi path");
            System.exit(1);
            return;
        }
        System.out.println("Patching nifi install at: " + args[0]);
        try {
            new Log4jPatch().patchNars(args[0]);
        } catch (Exception e) {
            System.err.println("Unable to patch nifi");
            e.printStackTrace();
            System.exit(1);
        }
    }
}

