/*
 * pam_sshauth: PAM module for authentication via a remote ssh server.
 * Copyright (C) 2010-2013 Scott Balneaves <sbalneav@ltsp.org>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with this program; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */

#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <string.h>
#include <strings.h>
#include <syslog.h>
#include <config.h>
#include <libssh2.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netdb.h>
#include <sys/select.h>
#include <unistd.h>
#include <arpa/inet.h>

#include <security/pam_modules.h>
#include <security/pam_ext.h>

#include "pam_sshauth.h"

/*
 * Bit flags OR-ed together in do_sshauth ()'s "method" mask; one bit per
 * authentication method found in the server's advertised method list.
 */
#define SSH_AUTH_METHOD_PASSWORD 1
#define SSH_AUTH_METHOD_INTERACTIVE 2

/*
 * Result codes used by the authentication helpers in this file.
 * SSH_AUTH_SUCCESS is compared directly against the return values of the
 * libssh2 userauth calls; SSH_AUTH_ERROR is this module's own failure value.
 */
#define SSH_AUTH_SUCCESS 0
#define SSH_AUTH_ERROR 1

/*
 * kbd_callback ()
 *
 * Keyboard-interactive callback handed to
 * libssh2_userauth_keyboard_interactive ().  Relays the server's
 * instruction text and prompts to the user through the PAM conversation
 * and stores the answers in responses[].
 *
 *   name, name_len                 - unused.
 *   instruction, instruction_len   - length-delimited instruction text.
 *   num_prompts, prompts           - length-delimited prompts to ask.
 *   responses                      - one entry per prompt, filled in here.
 *                                    Each text is heap allocated; ownership
 *                                    passes to libssh2.
 *   abstract                       - points at the pam_handle_t * that was
 *                                    given to libssh2_session_init_ex ().
 *
 * On any failure the function returns early and leaves the remaining
 * responses untouched.  Every answer is also stored as PAM_AUTHTOK, so the
 * last answer given is what stays in the PAM stack.
 */
static void
kbd_callback (const char *name, int name_len,
	      const char *instruction, int instruction_len, int num_prompts,
	      const LIBSSH2_USERAUTH_KBDINT_PROMPT * prompts,
	      LIBSSH2_USERAUTH_KBDINT_RESPONSE * responses, void **abstract)
{
  pam_handle_t *pamh = *abstract;
  int i;

  (void) name;
  (void) name_len;

  /*
   * Get any instructions the ssh session has generated, and send them to
   * the user via the pam message system.  The text is described by a
   * length, so make a NUL terminated copy before treating it as a string
   * (the same way the prompts below are handled).
   *
   * NOTE(review): if send_pam_msg () takes a printf-style format, this
   * server-supplied text should be passed through "%s" -- confirm against
   * its prototype.
   */

  if (instruction_len > 0)
    {
      char *text = malloc ((size_t) instruction_len + 1);

      if (text != NULL)
	{
	  memcpy (text, instruction, (size_t) instruction_len);
	  text[instruction_len] = '\0';
	  send_pam_msg (pamh, PAM_TEXT_INFO, text);
	  free (text);
	}
    }

  /*
   * Loop through the prompts that ssh has given us, and ask the
   * user via pam prompts for the answers.
   */

  for (i = 0; i < num_prompts; i++)
    {
      int style = prompts[i].echo ? PAM_PROMPT_ECHO_ON : PAM_PROMPT_ECHO_OFF;
      size_t prompt_len = prompts[i].length;
      int pam_retval;
      char *prompt;
      char *response = NULL;
      char *answer;

      prompt = malloc (prompt_len + 1);
      if (prompt == NULL)
	{
	  return;
	}
      if (prompt_len > 0)
	{
	  memcpy (prompt, prompts[i].text, prompt_len);
	}
      prompt[prompt_len] = '\0';

      pam_retval = pam_prompt (pamh, style, &response, "%s", prompt);
      free (prompt);

      /* A successful conversation may still hand back no answer. */
      if (pam_retval != PAM_SUCCESS || response == NULL)
	{
	  free (response);
	  return;
	}

      /* libssh2 takes ownership of this copy and releases it itself. */
      answer = strdup (response);
      if (answer == NULL)
	{
	  free (response);
	  return;
	}
      responses[i].text = answer;
      responses[i].length = (unsigned int) strlen (answer);

      /*
       * pam_set_item () keeps its own copy of the token, so the buffer
       * returned by pam_prompt () is ours to release afterwards.
       */
      pam_retval = pam_set_item (pamh, PAM_AUTHTOK, response);
      free (response);
      if (pam_retval != PAM_SUCCESS)
	{
	  return;
	}
    }
}

/*
 * auth_pw ()
 *
 * Conduct an ssh simple password based authentication.
 *
 * When try_first_pass is set, the PAM_AUTHTOK already present in the pam
 * stack is used if there is one; otherwise the user is prompted.  A
 * password obtained by prompting is stored as PAM_AUTHTOK once the server
 * has accepted it.
 *
 * Returns SSH_AUTH_SUCCESS on success.  Returns SSH_AUTH_ERROR when no
 * password could be obtained or the accepted password could not be stored,
 * and otherwise the error code from libssh2_userauth_password ().
 */

static int
auth_pw (pam_handle_t * pamh, const char *username, LIBSSH2_SESSION * session)
{
  int ssh_result;
  const char *password = NULL;	/* token in use; may be owned by PAM */
  char *prompted = NULL;	/* token from pam_prompt (); owned by us */

  /*
   * try_first_pass works with simple password authentication.
   */

  if (try_first_pass)
    {
      const void *item = NULL;

      if (pam_get_item (pamh, PAM_AUTHTOK, &item) != PAM_SUCCESS)
	{
	  pam_syslog (pamh, LOG_ERR,
		      "Couldn't obtain PAM_AUTHTOK from the pam stack.");
	  item = NULL;
	}
      password = item;
    }

  if (password == NULL)
    {
      if (pam_prompt (pamh, PAM_PROMPT_ECHO_OFF, &prompted, "Password:") !=
	  PAM_SUCCESS || prompted == NULL)
	{
	  pam_syslog (pamh, LOG_ERR,
		      "Couldn't obtain password from pam_prompt.");
	  free (prompted);
	  return SSH_AUTH_ERROR;
	}
      password = prompted;
    }

  ssh_result = libssh2_userauth_password (session, username, password);
  if (ssh_result == SSH_AUTH_SUCCESS)
    {
      /*
       * Store the password as the AUTHTOK whenever it came from the
       * prompt.  A token taken from the pam stack is already stored.
       */

      if (prompted != NULL
	  && pam_set_item (pamh, PAM_AUTHTOK, prompted) != PAM_SUCCESS)
	{
	  pam_syslog (pamh, LOG_ERR,
		      "Couldn't store password as PAM_AUTHTOK.");
	  ssh_result = SSH_AUTH_ERROR;
	}
    }
  else
    {
      char *errmsg = NULL;
      int len = 0;

      /* want_buf == 0: errmsg points into the session, do not free it. */
      libssh2_session_last_error (session, &errmsg, &len, 0);

      if (errmsg != NULL)
	{
	  send_pam_msg (pamh, PAM_TEXT_INFO, errmsg);
	}
    }

  /* pam_set_item () copied the token, so our prompted buffer can go. */
  free (prompted);
  return ssh_result;
}

/*
 * sshauth_connect ()
 *
 * Resolve host and open a TCP connection to it on port iport, trying each
 * address returned by the resolver (IPv4 or IPv6) in turn.
 *
 * Returns the connected socket descriptor, or -1 after logging the reason.
 * No descriptor is left open on failure.
 */

static int
sshauth_connect (pam_handle_t * pamh, const char *host, int iport)
{
  struct addrinfo hints;
  struct addrinfo *results = NULL;
  struct addrinfo *ai;
  char service[16];
  int sockfd = -1;
  int have_socket = 0;

  snprintf (service, sizeof service, "%d", iport);

  memset (&hints, 0, sizeof hints);
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;

  if (getaddrinfo (host, service, &hints, &results) != 0)
    {
      pam_syslog (pamh, LOG_ERR, "Couldn't resolve hostname %s.", host);
      return -1;
    }

  for (ai = results; ai != NULL; ai = ai->ai_next)
    {
      int fd = socket (ai->ai_family, ai->ai_socktype, ai->ai_protocol);

      if (fd < 0)
	{
	  continue;
	}
      have_socket = 1;

      if (connect (fd, ai->ai_addr, ai->ai_addrlen) == 0)
	{
	  sockfd = fd;
	  break;
	}
      close (fd);
    }

  freeaddrinfo (results);

  if (sockfd < 0)
    {
      if (have_socket)
	{
	  pam_syslog (pamh, LOG_ERR, "Couldn't connect to %s.", host);
	}
      else
	{
	  pam_syslog (pamh, LOG_ERR, "Couldn't create socket.");
	}
    }

  return sockfd;
}

/*
 * do_sshauth ()
 *
 * Authenticate by attempting an ssh connection.
 *
 * The host (and optionally the port, default 22) are read from the pam
 * handle's module data.  The server's host key is checked against
 * SYSTEM_KNOWNHOSTS, then keyboard-interactive and/or password
 * authentication is attempted, up to authtries times (at least once).
 *
 * Returns PAM_SUCCESS when the server accepted the user, the pam_get_data ()
 * error when the host is not available, PAM_SYSTEM_ERR when the connection
 * cannot be set up, and PAM_AUTH_ERR for every other failure.
 */

int
do_sshauth (pam_handle_t * pamh, const char *username)
{
  int method = 0;
  int ssh_result = SSH_AUTH_ERROR;
  int pam_result;
  int count;
  int iport = 22;
  int sockfd;
  int type = 0;
  int check;
  const void *data = NULL;
  const char *hostkey;
  const char *host;
  char *userauthlist;
  LIBSSH2_SESSION *session = NULL;
  LIBSSH2_KNOWNHOSTS *nh = NULL;
  size_t len = 0;
  FILE *khf;

  pam_result = pam_get_data (pamh, HOST, &data);
  if (pam_result != PAM_SUCCESS)
    {
      pam_syslog (pamh, LOG_ERR,
		  "Couldn't retrieve hostname from pam handle.");
      return pam_result;
    }
  host = data;
  if (host == NULL)
    {
      pam_syslog (pamh, LOG_ERR,
		  "Couldn't retrieve hostname from pam handle.");
      return PAM_SYSTEM_ERR;
    }

  /*
   * Couldn't retrieve port?  Fallback to 22.  A port that is present must
   * be a number in the valid TCP range.
   */

  data = NULL;
  if (pam_get_data (pamh, PORT, &data) == PAM_SUCCESS && data != NULL)
    {
      const char *port = data;
      char *end = NULL;
      long parsed = strtol (port, &end, 10);

      if (end == port || parsed < 1 || parsed > 65535)
	{
	  pam_syslog (pamh, LOG_ERR, "Invalid port \"%s\" for host %s.",
		      port, host);
	  return PAM_SYSTEM_ERR;
	}
      iport = (int) parsed;
    }

  /*
   * Establish our socket.  Nothing else is allocated yet, so a failure
   * here can return directly.
   */

  sockfd = sshauth_connect (pamh, host, iport);
  if (sockfd < 0)
    {
      return PAM_SYSTEM_ERR;
    }

  session = libssh2_session_init_ex (NULL, NULL, NULL, (void *) pamh);

  if (session == NULL)
    {
      pam_syslog (pamh, LOG_ERR,
		  "Couldn't allocate ssh session structure for host %s",
		  host);
      goto cleanup;
    }

  /*
   * tell libssh2 we want communications to use blocking
   */

  libssh2_session_set_blocking (session, 1);

  /*
   * Start the connection.
   */

  if (libssh2_session_handshake (session, sockfd) != 0)
    {
      pam_syslog (pamh, LOG_ERR, "Couldn't handshake ssh session.");
      goto cleanup;
    }

  nh = libssh2_knownhost_init (session);
  if (!nh)
    {
      pam_syslog (pamh, LOG_ERR, "Couldn't allocate known_host structure.");
      goto cleanup;
    }

  pam_debug (pamh, "Connected to host %s", host);

  /*
   * Load known hosts from the system file.
   */

  khf = fopen (SYSTEM_KNOWNHOSTS, "r");

  if (khf)
    {
      char buf[2048];

      while (fgets (buf, sizeof (buf), khf))
	{
	  /*
	   * The result is deliberately ignored: lines holding keys that
	   * libssh2 doesn't support are simply skipped.
	   */
	  (void) libssh2_knownhost_readline (nh, buf, strlen (buf),
					     LIBSSH2_KNOWNHOST_FILE_OPENSSH);
	}
      fclose (khf);
    }

  /*
   * Obtain the server's host key.  Without it the server cannot be
   * verified, so refuse to go any further.
   */

  hostkey = libssh2_session_hostkey (session, &len, &type);
  if (hostkey == NULL)
    {
      pam_syslog (pamh, LOG_ERR, "Couldn't obtain host key for %s.", host);
      goto cleanup;
    }

  /*
   * Is this server known to us? Check against what was loaded from the
   * system known hosts file.
   */

  check = libssh2_knownhost_check (nh, host, hostkey, len,
				   LIBSSH2_KNOWNHOST_TYPE_PLAIN |
				   LIBSSH2_KNOWNHOST_KEYENC_RAW, NULL);

  switch (check)
    {
    case LIBSSH2_KNOWNHOST_CHECK_MATCH:
      break;

    case LIBSSH2_KNOWNHOST_CHECK_NOTFOUND:
      if (nostrict)
	{
	  char *response = NULL;
	  int prompt_result;
	  int trusted;

	  /*
	   * Unknown server.  Ask the user, via the pam prompts, if they'd
	   * like to connect.
	   */
	  pam_debug (pamh, "Server not in known_hosts file.");
	  send_pam_msg (pamh, PAM_TEXT_INFO, "Server unknown. Trust?");
	  prompt_result = pam_prompt (pamh, PAM_PROMPT_ECHO_ON, &response,
				      "Type 'yes' to continue: ");
	  trusted = (prompt_result == PAM_SUCCESS && response != NULL
		     && strncasecmp (response, "yes", 3) == 0);
	  free (response);

	  if (!trusted)
	    {
	      pam_debug (pamh, "User does not trust host.");
	      goto cleanup;
	    }
	}
      else
	{
	  pam_debug (pamh,
		     "Unknown server and strict host checking.  Connection denied.");
	  send_pam_msg (pamh, PAM_TEXT_INFO,
			"Server is unknown, and strict checking enabled.  Connection denied.");
	  goto cleanup;
	}
      break;

    case LIBSSH2_KNOWNHOST_CHECK_MISMATCH:
      pam_debug (pamh, "Host key for %s changed. Connection terminated.",
		 host);
      pam_syslog (pamh, LOG_ERR, "Host key for %s changed.", host);
      send_pam_msg (pamh, PAM_TEXT_INFO, "Server's host key has changed.");
      send_pam_msg (pamh, PAM_TEXT_INFO,
		    "Possible man-in-the-middle attack. Connection terminated.");
      goto cleanup;

    case LIBSSH2_KNOWNHOST_CHECK_FAILURE:
    default:
      pam_syslog (pamh, LOG_ERR, "Known hosts check failed for %s.", host);
      goto cleanup;
    }

  /*
   * Find out what methods the ssh server supports for authentication.
   *
   * libssh2 returns NULL on failure, and also when the server accepts the
   * "none" method.  No credential has been checked in either case, so
   * both are treated as an authentication failure.
   */

  userauthlist = libssh2_userauth_list (session, username,
					(unsigned int) strlen (username));
  if (userauthlist == NULL)
    {
      pam_syslog (pamh, LOG_ERR,
		  "Couldn't obtain authentication methods from %s.", host);
      goto cleanup;
    }

  /*
   * List auth methods that have been returned.
   */

  pam_debug (pamh, "Authentication methods supported: %s", userauthlist);

  if (strstr (userauthlist, "password") != NULL)
    {
      method |= SSH_AUTH_METHOD_PASSWORD;
    }
  if (strstr (userauthlist, "keyboard-interactive") != NULL)
    {
      method |= SSH_AUTH_METHOD_INTERACTIVE;
    }

  if (method == 0)
    {
      pam_syslog (pamh, LOG_ERR,
		  "No supported authentication method offered by %s.", host);
      goto cleanup;
    }

  /*
   * Begin the authentication loop.  Loop until we're successfully
   * authenticated, or authtries times, whichever comes first.  At least
   * one attempt is always made, and a zero or negative authtries can no
   * longer keep the loop running.
   */

  count = authtries;

  do
    {
      /*
       * Try keyboard interactive first, if supported.
       */

      if (method & SSH_AUTH_METHOD_INTERACTIVE)
	{
	  /*
	   * SSH_AUTH_METHOD_INTERACTIVE requires
	   * ChallengeResponseAuthentication to be set to "yes"
	   * in /etc/ssh/sshd_config.
	   */
	  pam_debug (pamh, "Trying keyboard interactive authentication.");
	  ssh_result =
	    libssh2_userauth_keyboard_interactive (session, username,
						   &kbd_callback);
	  if (ssh_result == SSH_AUTH_SUCCESS)
	    {
	      break;
	    }
	}

      /*
       * Finally, plain password authentication.
       */

      if (method & SSH_AUTH_METHOD_PASSWORD)
	{
	  /*
	   * SSH_AUTH_METHOD_PASSWORD is simpler, and
	   * won't handle password expiry.
	   */
	  pam_debug (pamh, "Trying simple password authentication.");
	  ssh_result = auth_pw (pamh, username, session);
	  if (ssh_result == SSH_AUTH_SUCCESS)
	    {
	      break;
	    }
	}
    }
  while (--count > 0);

cleanup:

  if (nh != NULL)
    {
      libssh2_knownhost_free (nh);
    }

  if (session != NULL)
    {
      libssh2_session_disconnect (session, "Shutdown");
      libssh2_session_free (session);
      libssh2_exit ();
    }

  /* Descriptor 0 is a valid socket, so test against -1 rather than 0. */
  if (sockfd >= 0)
    {
      close (sockfd);
    }

  if (ssh_result == SSH_AUTH_SUCCESS)
    {
      pam_debug (pamh, "Authenticated successfully.");
      return PAM_SUCCESS;
    }
  else
    {
      pam_debug (pamh, "Authentication failed.");
      return PAM_AUTH_ERR;
    }
}
