001/**
002 * Copyright 2010-2013 The Kuali Foundation
003 *
004 * Licensed under the Educational Community License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.opensource.org/licenses/ecl2.php
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.kuali.common.util.channel.impl;
017
018import java.io.BufferedOutputStream;
019import java.io.ByteArrayInputStream;
020import java.io.ByteArrayOutputStream;
021import java.io.File;
022import java.io.IOException;
023import java.io.InputStream;
024import java.io.OutputStream;
025import java.util.ArrayList;
026import java.util.List;
027
028import org.apache.commons.io.FileUtils;
029import org.apache.commons.io.FilenameUtils;
030import org.apache.commons.io.IOUtils;
031import org.apache.commons.lang3.StringUtils;
032import org.kuali.common.util.Assert;
033import org.kuali.common.util.CollectionUtils;
034import org.kuali.common.util.FormatUtils;
035import org.kuali.common.util.LocationUtils;
036import org.kuali.common.util.PropertyUtils;
037import org.kuali.common.util.Str;
038import org.kuali.common.util.base.Threads;
039import org.kuali.common.util.channel.api.SecureChannel;
040import org.kuali.common.util.channel.model.ChannelContext;
041import org.kuali.common.util.channel.model.CommandContext;
042import org.kuali.common.util.channel.model.CommandResult;
043import org.kuali.common.util.channel.model.CopyDirection;
044import org.kuali.common.util.channel.model.CopyResult;
045import org.kuali.common.util.channel.model.RemoteFile;
046import org.kuali.common.util.channel.model.Status;
047import org.kuali.common.util.channel.util.ChannelUtils;
048import org.kuali.common.util.channel.util.SSHUtils;
049import org.slf4j.Logger;
050import org.slf4j.LoggerFactory;
051
052import com.google.common.base.Optional;
053import com.google.common.collect.ImmutableList;
054import com.jcraft.jsch.Channel;
055import com.jcraft.jsch.ChannelExec;
056import com.jcraft.jsch.ChannelSftp;
057import com.jcraft.jsch.JSch;
058import com.jcraft.jsch.JSchException;
059import com.jcraft.jsch.Session;
060import com.jcraft.jsch.SftpATTRS;
061import com.jcraft.jsch.SftpException;
062
063public final class DefaultSecureChannel implements SecureChannel {
064
065        private static final Logger logger = LoggerFactory.getLogger(DefaultSecureChannel.class);
066
067        private static final String SFTP = "sftp";
068        private static final String EXEC = "exec";
069        private static final String FORWARDSLASH = "/";
070
071        private final Session session;
072        private final ChannelSftp sftp;
073        private final ChannelContext context;
074
075        private boolean closed = false;
076
077        public DefaultSecureChannel(ChannelContext context) throws IOException {
078                Assert.noNulls(context);
079                this.context = context;
080                log();
081                try {
082                        JSch jsch = getJSch();
083                        this.session = openSession(jsch);
084                        this.sftp = openSftpChannel(session, context.getConnectTimeout());
085                } catch (JSchException e) {
086                        throw new IOException("Unexpected error opening secure channel", e);
087                }
088        }
089
090        @Override
091        public synchronized void close() {
092                if (closed) {
093                        return;
094                }
095                if (context.isEcho()) {
096                        logger.info("Closing secure channel [{}]", ChannelUtils.getLocation(context.getUsername(), context.getHostname()));
097                } else {
098                        logger.debug("Closing secure channel [{}]", ChannelUtils.getLocation(context.getUsername(), context.getHostname()));
099                }
100                closeQuietly(sftp);
101                closeQuietly(session);
102                this.closed = true;
103        }
104
105        @Override
106        public List<CommandResult> exec(String... commands) {
107                List<CommandResult> results = new ArrayList<CommandResult>();
108                List<String> copy = ImmutableList.copyOf(commands);
109                for (String command : copy) {
110                        CommandResult result = exec(command);
111                        results.add(result);
112                }
113                return results;
114        }
115
116        @Override
117        public List<CommandResult> exec(CommandContext... contexts) {
118                List<CommandResult> results = new ArrayList<CommandResult>();
119                List<CommandContext> copy = ImmutableList.copyOf(contexts);
120                for (CommandContext context : copy) {
121                        CommandResult result = exec(context);
122                        results.add(result);
123                }
124                return results;
125        }
126
127        @Override
128        public CommandResult exec(String command) {
129                return exec(new CommandContext.Builder(command).build());
130        }
131
132        @Override
133        public CommandResult exec(CommandContext context) {
134                StreamHandler handler = new StreamHandler(context);
135                ChannelExec exec = null;
136                try {
137                        // Preserve start time
138                        long start = System.currentTimeMillis();
139                        // Open an exec channel
140                        exec = getChannelExec();
141                        // Convert the command string to a byte array and store it on the exec channel
142                        exec.setCommand(context.getCommand());
143                        // Update the ChannelExec object with the stdin stream
144                        exec.setInputStream(context.getStdin().orNull());
145                        // Setup handling of stdin, stdout, and stderr
146                        handler.openStreams(exec, this.context.getEncoding());
147                        // Get ready to consume anything on stdin, and pump stdout/stderr back out to the consumers
148                        handler.startPumping();
149                        // This invokes the command on the remote system, consumes whatever is on stdin, and produces output to stdout/stderr
150                        connect(exec, context.getTimeout());
151                        // Wait until the channel reaches the "closed" state
152                        waitForClosed(exec, this.context.getWaitForClosedSleepMillis());
153                        // Wait for the streams to finish up
154                        handler.waitUntilDone();
155                        // Make sure there were no exceptions
156                        handler.validate();
157                        // Construct a result object
158                        CommandResult result = new CommandResult(context.getCommand(), exec.getExitStatus(), start);
159                        // Validate that things turned out ok (or that we don't care)
160                        validate(context, result);
161                        // Echo the command, if requested
162                        if (this.context.isEcho()) {
163                                String elapsed = FormatUtils.getTime(result.getElapsed());
164                                logger.info("{} - [{}]", new String(context.getCommand(), this.context.getEncoding()), elapsed);
165                        }
166                        // Return the result
167                        return result;
168                } catch (Exception e) {
169                        // Make sure the streams are disabled
170                        handler.disableQuietly();
171                        throw new IllegalStateException(e);
172                } finally {
173                        // Clean everything up
174                        IOUtils.closeQuietly(context.getStdin().orNull());
175                        closeQuietly(exec);
176                        handler.closeQuietly();
177                }
178        }
179
180        protected void validate(CommandContext context, CommandResult result) {
181                if (context.isIgnoreExitValue()) {
182                        return;
183                }
184                if (context.getSuccessCodes().size() == 0) {
185                        return;
186                }
187                List<Integer> codes = context.getSuccessCodes();
188                int exitValue = result.getExitValue();
189                for (int successCode : codes) {
190                        if (exitValue == successCode) {
191                                return;
192                        }
193                }
194                throw new IllegalStateException("Command exited with [" + exitValue + "].  Valid values are [" + CollectionUtils.toCSV(codes) + "]");
195        }
196
197        protected ChannelExec getChannelExec() throws JSchException {
198                ChannelExec exec = (ChannelExec) session.openChannel(EXEC);
199                if (context.isRequestPseudoTerminal()) {
200                        exec.setPty(true);
201                }
202                return exec;
203        }
204
205        @Override
206        public void execNoWait(String command) {
207                execNoWait(Str.getBytes(command, context.getEncoding()));
208        }
209
210        @Override
211        public void execNoWait(byte[] command) {
212                Assert.noNulls(command);
213                ChannelExec exec = null;
214                try {
215                        if (context.isEcho()) {
216                                logger.info("{}", Str.getString(command, context.getEncoding()));
217                        }
218                        // Open an exec channel
219                        exec = getChannelExec();
220                        // Store the command on the exec channel
221                        exec.setCommand(command);
222                        // Execute the command.
223                        // This consumes anything from stdin and stores output in stdout/stderr
224                        connect(exec, Optional.<Integer> absent());
225                } catch (Exception e) {
226                        throw new IllegalStateException(e);
227                } finally {
228                        closeQuietly(exec);
229                }
230        }
231
232        protected void waitForClosed(ChannelExec exec, long millis) {
233                while (!exec.isClosed()) {
234                        Threads.sleep(millis);
235                }
236        }
237
238        @Override
239        public RemoteFile getWorkingDirectory() {
240                try {
241                        String workingDirectory = sftp.pwd();
242                        return getMetaData(workingDirectory);
243                } catch (SftpException e) {
244                        throw new IllegalStateException(e);
245                }
246        }
247
248        protected void log() {
249                if (context.isEcho()) {
250                        logger.info("Opening secure channel [{}] encoding={}", ChannelUtils.getLocation(context.getUsername(), context.getHostname()), context.getEncoding());
251                } else {
252                        logger.debug("Opening secure channel [{}] encoding={}", ChannelUtils.getLocation(context.getUsername(), context.getHostname()), context.getEncoding());
253                }
254                logger.debug("Private key files - {}", context.getPrivateKeyFiles().size());
255                logger.debug("Private key strings - {}", context.getPrivateKeys().size());
256                logger.debug("Private key config file - {}", context.getConfig());
257                logger.debug("Private key config file use - {}", context.isUseConfigFile());
258                logger.debug("Include default private key locations - {}", context.isIncludeDefaultPrivateKeyLocations());
259                logger.debug("Known hosts file - {}", context.getKnownHosts());
260                logger.debug("Port - {}", context.getPort());
261                if (context.getConnectTimeout().isPresent()) {
262                        logger.debug("Connect timeout - {}", context.getConnectTimeout().get());
263                }
264                logger.debug("Strict host key checking - {}", context.isStrictHostKeyChecking());
265                logger.debug("Configuring channel with {} custom options", context.getOptions().size());
266                PropertyUtils.debug(context.getOptions());
267        }
268
269        protected ChannelSftp openSftpChannel(Session session, Optional<Integer> timeout) throws JSchException {
270                ChannelSftp sftp = (ChannelSftp) session.openChannel(SFTP);
271                connect(sftp, timeout);
272                return sftp;
273        }
274
275        protected void connect(Channel channel, Optional<Integer> timeout) throws JSchException {
276                if (timeout.isPresent()) {
277                        channel.connect(timeout.get());
278                } else {
279                        channel.connect();
280                }
281        }
282
283        protected void closeQuietly(Session session) {
284                if (session != null) {
285                        session.disconnect();
286                }
287        }
288
289        protected void closeQuietly(Channel channel) {
290                if (channel != null) {
291                        channel.disconnect();
292                }
293        }
294
295        protected Session openSession(JSch jsch) throws JSchException {
296                Session session = jsch.getSession(context.getUsername().orNull(), context.getHostname(), context.getPort());
297
298                session.setConfig(context.getOptions());
299                if (context.getConnectTimeout().isPresent()) {
300                        session.connect(context.getConnectTimeout().get());
301                } else {
302                        session.connect();
303                }
304                return session;
305        }
306
307        protected JSch getJSch() {
308                try {
309                        JSch jsch = getJSch(context.getPrivateKeyFiles(), context.getPrivateKeys());
310                        File knownHosts = context.getKnownHosts();
311                        if (context.isUseKnownHosts() && knownHosts.exists()) {
312                                String path = LocationUtils.getCanonicalPath(knownHosts);
313                                jsch.setKnownHosts(path);
314                        }
315                        return jsch;
316                } catch (JSchException e) {
317                        throw new IllegalStateException("Unexpected error", e);
318                }
319        }
320
321        protected JSch getJSch(List<File> privateKeys, List<String> privateKeyStrings) throws JSchException {
322                JSch jsch = new JSch();
323                for (File privateKey : privateKeys) {
324                        String path = LocationUtils.getCanonicalPath(privateKey);
325                        jsch.addIdentity(path);
326                }
327                int count = 0;
328                for (String privateKeyString : privateKeyStrings) {
329                        String name = "privateKeyString-" + Integer.toString(count++);
330                        byte[] bytes = Str.getBytes(privateKeyString, context.getEncoding());
331                        jsch.addIdentity(name, bytes, null, null);
332                }
333                return jsch;
334        }
335
336        protected static List<File> getUniquePrivateKeyFiles(List<File> privateKeys, boolean useConfigFile, File config, boolean includeDefaultPrivateKeyLocations) {
337                List<String> paths = new ArrayList<String>();
338                for (File privateKey : privateKeys) {
339                        paths.add(LocationUtils.getCanonicalPath(privateKey));
340                }
341                if (useConfigFile) {
342                        for (String path : SSHUtils.getFilenames(config)) {
343                                paths.add(path);
344                        }
345                }
346                if (includeDefaultPrivateKeyLocations) {
347                        for (String path : SSHUtils.PRIVATE_KEY_DEFAULTS) {
348                                paths.add(path);
349                        }
350                }
351                List<String> uniquePaths = CollectionUtils.getUniqueStrings(paths);
352                return SSHUtils.getExistingAndReadable(uniquePaths);
353        }
354
355        @Override
356        public RemoteFile getMetaData(String absolutePath) {
357                Assert.noBlanks(absolutePath);
358                return fillInAttributes(absolutePath);
359        }
360
361        @Override
362        public void deleteFile(String absolutePath) {
363                RemoteFile file = getMetaData(absolutePath);
364                if (isStatus(file, Status.MISSING)) {
365                        return;
366                }
367                if (file.isDirectory()) {
368                        throw new IllegalArgumentException("[" + ChannelUtils.getLocation(context.getUsername(), context.getHostname(), file) + "] is a directory.");
369                }
370                try {
371                        sftp.rm(absolutePath);
372                        if (context.isEcho()) {
373                                logger.info("deleted -> [{}]", absolutePath);
374                        }
375                } catch (SftpException e) {
376                        throw new IllegalStateException(e);
377                }
378        }
379
380        @Override
381        public boolean exists(String absolutePath) {
382                RemoteFile file = getMetaData(absolutePath);
383                return isStatus(file, Status.EXISTS);
384        }
385
386        @Override
387        public boolean isDirectory(String absolutePath) {
388                RemoteFile file = getMetaData(absolutePath);
389                return isStatus(file, Status.EXISTS) && file.isDirectory();
390        }
391
392        protected RemoteFile fillInAttributes(String path) {
393                try {
394                        SftpATTRS attributes = sftp.stat(path);
395                        return fillInAttributes(path, attributes);
396                } catch (SftpException e) {
397                        return handleNoSuchFileException(path, e);
398                }
399        }
400
401        protected RemoteFile fillInAttributes(String path, SftpATTRS attributes) {
402                boolean directory = attributes.isDir();
403                int permissions = attributes.getPermissions();
404                int userId = attributes.getUId();
405                int groupId = attributes.getGId();
406                long size = attributes.getSize();
407                Status status = Status.EXISTS;
408                return new RemoteFile.Builder(path).directory(directory).permissions(permissions).userId(userId).groupId(groupId).size(size).status(status).build();
409        }
410
411        @Override
412        public CopyResult scp(File source, RemoteFile destination) {
413                Assert.notNull(source);
414                Assert.exists(source);
415                Assert.isFalse(source.isDirectory(), "[" + source + "] is a directory");
416                Assert.isTrue(source.canRead(), "[" + source + "] not readable");
417                return scp(LocationUtils.getCanonicalURLString(source), destination);
418        }
419
420        @Override
421        public CopyResult scpToDir(File source, RemoteFile directory) {
422                String filename = source.getName();
423                String absolutePath = getAbsolutePath(directory.getAbsolutePath(), filename);
424                RemoteFile file = new RemoteFile.Builder(absolutePath).clone(directory).build();
425                return scp(source, file);
426        }
427
428        @Override
429        public CopyResult scp(String location, RemoteFile destination) {
430                Assert.notNull(location);
431                Assert.isTrue(LocationUtils.exists(location), location + " does not exist");
432                InputStream in = null;
433                try {
434                        in = LocationUtils.getInputStream(location);
435                        return scp(in, destination);
436                } catch (Exception e) {
437                        throw new IllegalStateException(e);
438                } finally {
439                        IOUtils.closeQuietly(in);
440                }
441        }
442
443        @Override
444        public CopyResult scpString(String string, RemoteFile destination) {
445                Assert.notNull(string);
446                InputStream in = new ByteArrayInputStream(Str.getBytes(string, context.getEncoding()));
447                CopyResult result = scp(in, destination);
448                IOUtils.closeQuietly(in);
449                return result;
450        }
451
452        @Override
453        public String toString(RemoteFile source) {
454                Assert.notNull(source);
455                ByteArrayOutputStream out = new ByteArrayOutputStream();
456                try {
457                        scp(source, out);
458                        return out.toString(context.getEncoding());
459                } catch (IOException e) {
460                        throw new IllegalStateException("Unexpected IO error", e);
461                } finally {
462                        IOUtils.closeQuietly(out);
463                }
464        }
465
466        @Override
467        public CopyResult scp(InputStream source, RemoteFile destination) {
468                Assert.notNull(source);
469                try {
470                        long start = System.currentTimeMillis();
471                        createDirectories(destination);
472                        sftp.put(source, destination.getAbsolutePath());
473                        RemoteFile meta = getMetaData(destination.getAbsolutePath());
474                        CopyResult result = new CopyResult(start, meta.getSize().get(), CopyDirection.TO_REMOTE);
475                        to(destination, result);
476                        return result;
477                } catch (SftpException e) {
478                        throw new IllegalStateException(e);
479                }
480        }
481
482        protected String getAbsolutePath(String absolutePath, String filename) {
483                if (StringUtils.endsWith(absolutePath, FORWARDSLASH)) {
484                        return absolutePath + filename;
485                } else {
486                        return absolutePath + FORWARDSLASH + filename;
487                }
488        }
489
490        @Override
491        public CopyResult scpToDir(String location, RemoteFile directory) {
492                String filename = LocationUtils.getFilename(location);
493                String absolutePath = getAbsolutePath(directory.getAbsolutePath(), filename);
494                RemoteFile file = new RemoteFile.Builder(absolutePath).clone(directory).build();
495                return scp(location, file);
496        }
497
498        @Override
499        public CopyResult scp(RemoteFile source, File destination) {
500                OutputStream out = null;
501                try {
502                        out = new BufferedOutputStream(FileUtils.openOutputStream(destination));
503                        return scp(source, out);
504                } catch (Exception e) {
505                        throw new IllegalStateException(e);
506                } finally {
507                        IOUtils.closeQuietly(out);
508                }
509        }
510
511        @Override
512        public CopyResult scp(String absolutePath, OutputStream out) throws IOException {
513                try {
514                        long start = System.currentTimeMillis();
515                        sftp.get(absolutePath, out);
516                        RemoteFile meta = getMetaData(absolutePath);
517                        CopyResult result = new CopyResult(start, meta.getSize().get(), CopyDirection.FROM_REMOTE);
518                        from(absolutePath, result);
519                        return result;
520                } catch (SftpException e) {
521                        throw new IOException("Unexpected IO error", e);
522                }
523        }
524
525        /**
526         * Show information about the transfer of data to a remote server
527         */
528        protected void to(RemoteFile destination, CopyResult result) {
529                if (context.isEcho()) {
530                        String elapsed = FormatUtils.getTime(result.getElapsedMillis());
531                        String rate = FormatUtils.getRate(result.getElapsedMillis(), result.getAmountInBytes());
532                        Object[] args = { destination.getAbsolutePath(), elapsed, rate };
533                        logger.info("created -> [{}] - [{}, {}]", args);
534                }
535        }
536
537        /**
538         * Show information about the transfer of data from a remote server
539         */
540        protected void from(String absolutePath, CopyResult result) {
541                if (context.isEcho()) {
542                        String elapsed = FormatUtils.getTime(result.getElapsedMillis());
543                        String rate = FormatUtils.getRate(result.getElapsedMillis(), result.getAmountInBytes());
544                        Object[] args = { absolutePath, elapsed, rate };
545                        logger.info("copied <- [{}] - [{}, {}]", args);
546                }
547        }
548
549        @Override
550        public CopyResult scp(RemoteFile source, OutputStream out) throws IOException {
551                return scp(source.getAbsolutePath(), out);
552        }
553
554        @Override
555        public CopyResult scpToDir(RemoteFile source, File destination) {
556                String filename = FilenameUtils.getName(source.getAbsolutePath());
557                File newDestination = new File(destination, filename);
558                return scp(source, newDestination);
559        }
560
561        @Override
562        public void createDirectory(RemoteFile dir) {
563                Assert.isTrue(dir.isDirectory());
564                try {
565                        createDirectories(dir);
566                        if (context.isEcho()) {
567                                logger.info("mkdir -> [{}]", dir.getAbsolutePath());
568                        }
569                } catch (SftpException e) {
570                        throw new IllegalStateException(e);
571                }
572        }
573
574        protected void createDirectories(RemoteFile file) throws SftpException {
575                boolean directoryIndicator = file.isDirectory();
576                RemoteFile remoteFile = fillInAttributes(file.getAbsolutePath());
577                validate(remoteFile, directoryIndicator);
578                List<String> directories = LocationUtils.getNormalizedPathFragments(file.getAbsolutePath(), file.isDirectory());
579                for (String directory : directories) {
580                        RemoteFile parentDir = fillInAttributes(directory);
581                        validate(parentDir, true);
582                        if (!isStatus(parentDir, Status.EXISTS)) {
583                                mkdir(parentDir);
584                        }
585                }
586        }
587
588        protected boolean isStatus(RemoteFile file, Status status) {
589                Optional<Status> remoteStatus = file.getStatus();
590                if (remoteStatus.isPresent()) {
591                        return remoteStatus.get().equals(status);
592                } else {
593                        return false;
594                }
595        }
596
597        protected void validate(RemoteFile file, boolean directoryIndicator) {
598                // Make sure status has been filled in
599                Assert.isTrue(file.getStatus().isPresent());
600
601                // Convenience flags
602                boolean missing = isStatus(file, Status.MISSING);
603                boolean exists = isStatus(file, Status.EXISTS);
604
605                // It it is supposed to be a directory, make sure it's a directory
606                // If it is supposed to be a regular file, make sure it's a regular file
607                boolean correctFileType = file.isDirectory() == directoryIndicator;
608
609                // Is everything as it should be?
610                boolean valid = missing || exists && correctFileType;
611
612                Assert.isTrue(valid, getInvalidExistingFileMessage(file));
613        }
614
615        protected String getInvalidExistingFileMessage(RemoteFile existing) {
616                if (existing.isDirectory()) {
617                        return "[" + ChannelUtils.getLocation(context.getUsername(), context.getHostname(), existing) + "] is an existing directory. Unable to create file.";
618                } else {
619                        return "[" + ChannelUtils.getLocation(context.getUsername(), context.getHostname(), existing) + "] is an existing file. Unable to create directory.";
620                }
621        }
622
623        protected void mkdir(RemoteFile dir) {
624                try {
625                        String path = dir.getAbsolutePath();
626                        logger.debug("Creating [{}]", path);
627                        sftp.mkdir(path);
628                        setAttributes(dir);
629                } catch (SftpException e) {
630                        throw new IllegalStateException(e);
631                }
632        }
633
634        protected void setAttributes(RemoteFile file) throws SftpException {
635                String path = file.getAbsolutePath();
636                if (file.getPermissions().isPresent()) {
637                        sftp.chmod(file.getPermissions().get(), path);
638                }
639                if (file.getGroupId().isPresent()) {
640                        sftp.chgrp(file.getGroupId().get(), path);
641                }
642                if (file.getUserId().isPresent()) {
643                        sftp.chown(file.getUserId().get(), path);
644                }
645        }
646
647        protected RemoteFile handleNoSuchFileException(String path, SftpException e) {
648                if (isNoSuchFileException(e)) {
649                        return new RemoteFile.Builder(path).status(Status.MISSING).build();
650                } else {
651                        throw new IllegalStateException(e);
652                }
653        }
654
655        protected boolean isNoSuchFileException(SftpException exception) {
656                return exception.id == ChannelSftp.SSH_FX_NO_SUCH_FILE;
657        }
658
659}