Skip to content

Commit

Permalink
Fix for SAML Object serialization which prevented storage of sessions…
Browse files Browse the repository at this point in the history
… between server restarts.
  • Loading branch information
vschafer committed Mar 21, 2011
1 parent c153bdf commit ceab3ef
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import org.opensaml.xml.parse.ParserPool;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.access.BootstrapException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.security.saml.parser.ParserPoolHolder;
Expand All @@ -15,16 +14,20 @@
*/
public class SAMLBootstrap implements BeanFactoryPostProcessor {

@Autowired
ParserPool parserPool;

/**
* Automatically called to initialize whole module. Localizes parserPool from the factory and stores it.
*
* @param beanFactory bean factory
* @throws BeansException errors
*/
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
try {
DefaultBootstrap.bootstrap();
new ParserPoolHolder(parserPool);
ParserPool pool = beanFactory.getBean(ParserPool.class);
new ParserPoolHolder(pool);
} catch (ConfigurationException e) {
throw new BootstrapException("Error invoking OpenSAML bootrap", e);
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
*/
package org.springframework.security.saml.parser;

import org.opensaml.xml.parse.BasicParserPool;
import org.opensaml.xml.parse.ParserPool;

/**
Expand All @@ -28,7 +27,7 @@ public class ParserPoolHolder {
/**
* Pool instance.
*/
private static ParserPool pool = new BasicParserPool();
private static ParserPool pool;

/**
* Initializes the static parserPool property and makes it available for getPool calls.
Expand All @@ -46,4 +45,5 @@ public ParserPoolHolder(ParserPool pool) {
public static ParserPool getPool() {
return pool;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,18 @@ public SAMLCollection(List<T> object) {
super(object);
}

@Override
public List<T> getObject() {
if (object == null) { // Lazy parse
parse();
}
return super.getObject();
}

/**
* Custom serialization logic which transform List of XMLObject into List of Strings.
*
* @param out output stream
*
* @throws java.io.IOException error performing XMLObject serialization
*/
private void writeObject(ObjectOutputStream out) throws IOException {
Expand All @@ -63,30 +70,40 @@ private void writeObject(ObjectOutputStream out) throws IOException {
out.writeObject(serializedObject);
} catch (MessageEncodingException e) {
log.error("Error serializing SAML object", e);
throw new IOException("Error serializing SAML object: " + e.getMessage());
throw new IOException("Error serializing SAML object: " + e.getMessage());
}
}

/**
* Deserializes List of XMLObjects from the stream.
* Deserializes List of XMLObjects from the stream. Parsing of the content is done lazily upon access
* to the object. The reason for this is the fact that parser pool may not be initialized during system startup
* and the object may be stored in a serialized session.
*
* @param in input stream containing XMLObject as String
*
* @throws IOException error deserializing String to XMLObject
* @throws ClassNotFoundException class not found
*/
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
this.serializedObject = (ArrayList<String>) in.readObject();
}

/**
* Lazily parsers serialized data.
*/
private void parse() {
try {
ArrayList<String> serializedItems = (ArrayList<String>) in.readObject();
List<T> items = new LinkedList<T>();
for (String item : serializedItems) {
items.add(unmarshallMessage(new StringReader(item)));
ArrayList<String> serializedItems = (ArrayList<String>) serializedObject;
if (serializedItems != null) {
List<T> items = new LinkedList<T>();
for (String item : serializedItems) {
items.add(unmarshallMessage(new StringReader(item)));
}
object = items;
}
this.serializedObject = serializedItems;
object = items;
} catch (MessageDecodingException e) {
log.error("Error de-serializing SAML object", e);
throw new IOException("Error de-serializing SAML object: " + e.getMessage());
throw new RuntimeException("Error de-serializing SAML object: " + e.getMessage());
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

/**
* SAMLObject is a wrapper around XMLObject instances of OpenSAML library As some XMLObjects are stored
* inside the HttpSession (which could be potentially sent to another cluster members), we need
* mechanism to enable serialization of these instances.
*
* @author Vladimir Schafer
* @param <T> type of XMLObject
* @author Vladimir Schafer
*/
public class SAMLObject<T extends XMLObject> extends SAMLBase<T, T> {

Expand All @@ -45,14 +48,16 @@ public SAMLObject(T object) {

@Override
public T getObject() {
if (object == null) { // Lazy parse
parse();
}
return super.getObject();
}

/**
* Custom serialization logic which transform XMLObject into String.
*
* @param out output stream
*
* @throws java.io.IOException error performing XMLObject serialization
*/
private void writeObject(ObjectOutputStream out) throws IOException {
Expand All @@ -68,20 +73,30 @@ private void writeObject(ObjectOutputStream out) throws IOException {
}

/**
* Deserializes XMLObject from the stream.
* Deserializes XMLObject from the stream. Parsing of the content is done lazily upon access
* to the object. The reason for this is the fact that parser pool may not be initialized during system startup
* and the object may be stored in a serialized session.
*
* @param in input stream contaiing XMLObject as String
*
* @throws IOException error deserializing String to XMLObject
* @throws ClassNotFoundException class not found
*/
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
serializedObject = in.readUTF();
}

/**
* Lazily parsers serialized data.
*/
private void parse() {
try {
serializedObject = in.readUTF();
object = unmarshallMessage(new StringReader((String) serializedObject));
if (serializedObject != null) {
object = unmarshallMessage(new StringReader((String) serializedObject));
}
} catch (MessageDecodingException e) {
log.error("Error de-serializing SAML object", e);
throw new IOException("Error de-serializing SAML object: " + e.getMessage());
throw new RuntimeException("Error de-serializing SAML object: " + e.getMessage());
}
}

}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2009 Vladimir Schäfer
/* Copyright 2009 Vladimir Schaefer
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -25,6 +25,7 @@
import org.opensaml.xml.io.MarshallingException;
import org.opensaml.xml.io.Unmarshaller;
import org.opensaml.xml.io.UnmarshallingException;
import org.opensaml.xml.parse.ParserPool;
import org.springframework.security.saml.SAMLTestBase;
import org.w3c.dom.Element;

Expand All @@ -34,7 +35,7 @@
import static org.easymock.EasyMock.*;

/**
* @author Vladimir Schäfer
* @author Vladimir Schaefer
*/
public class SAMLObjectTest extends SAMLTestBase {

Expand Down Expand Up @@ -70,7 +71,7 @@ public void testNoNullArgument() {
*
* @throws Exception error
*/
@Test(expected = IOException.class)
@Test(expected = RuntimeException.class)
public void testMarshalWithoutPoolSet() throws Exception {
new ParserPoolHolder(null);
ByteArrayOutputStream outStream = new ByteArrayOutputStream();
Expand All @@ -79,7 +80,34 @@ public void testMarshalWithoutPoolSet() throws Exception {

ByteArrayInputStream inputStream = new ByteArrayInputStream(outStream.toByteArray());
ObjectInputStream input = new ObjectInputStream(inputStream);
input.readObject();
SAMLBase o = (SAMLBase) input.readObject();
o.getObject();

}

/**
* Verifies that deserializaion succeeds when parser pool is set when the object is accessed.
*
* @throws Exception error
*/
@Test
public void testMarshalWithLazyPoolSet() throws Exception {

ParserPool pool = ParserPoolHolder.getPool();
new ParserPoolHolder(null);

ByteArrayOutputStream outStream = new ByteArrayOutputStream();
ObjectOutputStream stream = new ObjectOutputStream(outStream);
stream.writeObject(assertionObject);

ByteArrayInputStream inputStream = new ByteArrayInputStream(outStream.toByteArray());
ObjectInputStream input = new ObjectInputStream(inputStream);
SAMLBase o = (SAMLBase) input.readObject();

new ParserPoolHolder(pool);

o.getObject();

}

/**
Expand Down Expand Up @@ -123,7 +151,7 @@ public void testMarshallingError() throws Exception {
*
* @throws Exception error
*/
@Test(expected = IOException.class)
@Test(expected = RuntimeException.class)
public void testNoUnmarshaller() throws Exception {

ByteArrayOutputStream outStream = new ByteArrayOutputStream();
Expand All @@ -137,10 +165,12 @@ public void testNoUnmarshaller() throws Exception {

try {
Configuration.getUnmarshallerFactory().deregisterUnmarshaller(assertion.getElementQName());
input.readObject();
SAMLBase o = (SAMLBase) input.readObject();
o.getObject();
} finally {
Configuration.getUnmarshallerFactory().registerUnmarshaller(assertion.getElementQName(), old);
}

}

class TestObject extends ActionImpl {
Expand All @@ -154,7 +184,7 @@ class TestObject extends ActionImpl {
*
* @throws Exception error
*/
@Test(expected = IOException.class)
@Test(expected = RuntimeException.class)
public void testWrongXMLInStream() throws Exception {

ByteArrayOutputStream outStream = new ByteArrayOutputStream();
Expand All @@ -172,7 +202,8 @@ public void testWrongXMLInStream() throws Exception {

try {
replay(mock);
input.readObject();
SAMLBase o = (SAMLBase) input.readObject();
o.getObject();
verify(mock);
} finally {
Configuration.getUnmarshallerFactory().registerUnmarshaller(assertion.getElementQName(), old);
Expand Down

0 comments on commit ceab3ef

Please sign in to comment.